What?
Using data-driven methods to classify reactions into different categories.
Why?
Sorting (new) reactions into categories supports better documentation and a broader understanding of the mechanisms possible in these reactions.
How?
A chemical reaction is described using a three-level reaction ontology based on the hierarchy proposed by Carey, Laffan, Thomson and Williams in 2006.
In this scheme, every reaction is grouped using 3 layers of information: superclass >> class >> type
For example, the Suzuki coupling sits in this hierarchy as follows (a second class under the same superclass is shown for comparison):
"3 Carbon-Carbon bond formation" (Superclass)
|- "3.1. Suzuki coupling" (Class)
|- 3.1.1 Bromo OR 3.1.2 Chloro OR 3.1.3 Iodo Suzuki Coupling (Type)
|- "3.5 Palladium-catalyzed C-C bond formation" (Class)
|- 3.5.3 Negishi coupling (Type)
Researchers at NextMove Software were among the first groups to scrape the US patent literature for chemical reactions and to use the categories defined above to classify these reactions systematically.
Another important step in this process is atom-atom mapping of the chemical reactions. While not strictly required (newer algorithms can perform this task without explicit atom mapping), it is an important pre-processing and standardization step.
Atom-atom mapping helps to understand which reactant atom becomes which product atom during the reaction. From this information it is possible to identify reaction centers and sets of bonds made and broken during the reaction.
It is also useful for distinguishing reactants from reagents:
By convention:
Reactants: contribute one or more atoms to the product(s)
Reagents (solvent, catalyst): Do not contribute any atom to the product(s)
Relevant papers in this field can be found here
Using the Schneider et al. paper as the reference: https://pubs.acs.org/doi/10.1021/ci5006614
import os
import pandas as pd
import numpy as np
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
from IPython.display import Image
IPythonConsole.ipython_useSVG=True
try:
import cPickle as pickle
except:
import pickle
def display_rxn(rxn_smarts):
    """Render a reaction string as an 800x200 PNG image for notebook display.

    The string is parsed with ``useSmiles=True``, i.e. as reaction SMILES
    rather than as a SMARTS query.
    """
    reaction = AllChem.ReactionFromSmarts(rxn_smarts, useSmiles=True)
    canvas = Draw.MolDraw2DCairo(800, 200)
    canvas.DrawReaction(reaction)
    return Image(canvas.GetDrawingText())
# Chem.WrapLogs() sends RDKit's C++ log output through Python (per the RDKit
# docs), so messages appear in the notebook rather than the server console.
Chem.WrapLogs()
# Then raise the RDKit logger threshold to CRITICAL to suppress lower-severity
# messages (e.g. parse warnings while converting many reactions).
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
# High DPI rendering for mac: 'retina' renders inline figures at 2x resolution
%config InlineBackend.figure_format = 'retina'
# Plot matplotlib plots with white background (facecolor of inline figures):
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
# Directory holding the Schneider et al. dataset files (relative to the notebook's working directory).
data_dir = 'DATA/Schneider_etal_ChemReactionClassification/data'
# Collection of reaction-type labels; sorted later into `rtypes` to fix the class order.
# NOTE: pickle.load can execute arbitrary code - only load these files from a trusted source.
with open(os.path.join(data_dir, 'reactionTypes_training_test_set_patent_data.pkl'), 'rb') as f:
    reaction_types = pickle.load(f)
# reaction classification data
# Mapping from reaction-type label to its descriptive name (indexed as names_rTypes[label] later on).
with open(os.path.join(data_dir, 'names_rTypes_classes_superclasses_training_test_set_patent_data.pkl'), 'rb') as f:
    names_rTypes = pickle.load(f)
# Number of reaction-type labels (shown only when this is the last line of a notebook cell).
len(reaction_types)
names_rTypes is a superset containing all possible reaction types.
# Display the label -> descriptive-name mapping.
names_rTypes
import gzip

# The gzipped file holds a stream of consecutively pickled records, each a
# (smiles, patent number, reaction class) triple; read until EOF.
# `with` guarantees the file handle is closed (it was previously left open).
# NOTE: pickle.load can execute arbitrary code - only load trusted data files.
rxn_data_list = []
with gzip.open(os.path.join(data_dir, 'training_test_set_patent_data.pkl.gz'), 'rb') as infile:
    while True:
        try:
            smi, lbl, klass = pickle.load(infile)
        except EOFError:
            break
        rxn_data_list.append([smi, lbl, klass])
        # Progress report every 10,000 records.
        if len(rxn_data_list) % 10000 == 0:
            print("Done " + str(len(rxn_data_list)))
len(rxn_data_list)
# Column order matches the (smiles, lbl, klass) records read from the pickle stream.
column_names = ['SMILES', 'Patent No', 'Rxn Class']
df_rxn = pd.DataFrame(rxn_data_list, columns=column_names)
# Each bare expression below is an inspection display (presumably the last line
# of its own notebook cell; otherwise it has no effect).
df_rxn
df_rxn.dtypes
# Number of reactions per class label.
df_rxn['Rxn Class'].value_counts()
# Look at one arbitrary example reaction (row 42069).
df_rxn.iloc[42069]
df_rxn.SMILES[42069]
display_rxn(df_rxn.SMILES[42069])
Generate RDKit ChemicalReaction objects from the reaction SMILES
%%time
# Convert Smiles strings to reaction objects - this takes the most time and might be helpful if parallelized
from rdkit.Chem import rdChemReactions # Main reaction analysis class
df_rxn['rxn_obj'] = df_rxn['SMILES'].apply(rdChemReactions.ReactionFromSmarts)
df_rxn['rxn_obj'][42069]
temp_rxn = df_rxn['rxn_obj'][42069]
type(temp_rxn)
Fingerprints in RDkit
More information here: https://www.rdkit.org/UGM/2012/Landrum_RDKit_UGM.Fingerprints.Final.pptx.pdf
RDKit's reaction functionality now lives in the `rdkit.Chem.rdChemReactions` module: http://rdkit.org/docs/source/rdkit.Chem.rdChemReactions.html
Here I am using Reaction Difference FPs for converting to FPs - another option is to use the Transformation FPs
Fingerprint Type | Meaning |
---|---|
Difference FPs | Take difference of structural FPs of reactant and product |
Structural FPs | Concatenate the FPs of reactant and product in 1 vector |
Options for incorporating the agents (reagents, solvents, catalysts):
- Adding in agent during the fingerprint generation -- weighting its importance
- Appending the agent after the FP formation
# Inspect the default reaction-fingerprint parameters, reached via AllChem and via rdChemReactions.
AllChem.ReactionFingerprintParams()
Chem.rdChemReactions.ReactionFingerprintParams()
# Difference fingerprint of the example reaction with default parameters.
rdChemReactions.CreateDifferenceFingerprintForReaction(temp_rxn)
## This is taken from the paper SI
def create_agent_feature_FP(rxn):
    """Sum nine RDKit descriptors over all agents of a reaction.

    Adapted from the Schneider et al. SI. ``rxn.RemoveUnmappedReactantTemplates()``
    is called first, so ``rxn`` is modified in place before its agents are read.

    Returns a list of 9 floats, each summed over the agents:
        [0] MolWt            [1] number of atoms     [2] number of rings
        [3] MolLogP          [4] radical electrons   [5] TPSA
        [6] heteroatoms      [7] H-bond acceptors    [8] H-bond donors
    An agent whose descriptors raise an error is skipped as a whole.
    """
    # Local import: Descriptors is not imported anywhere else in this notebook,
    # and the error handling below would otherwise hide the NameError (the
    # function then silently returned all zeros).
    from rdkit.Chem import Descriptors

    rxn.RemoveUnmappedReactantTemplates()
    agent_feature_Fp = [0.0] * 9
    for nra in range(rxn.GetNumAgentTemplates()):
        mol = rxn.GetAgentTemplate(nra)
        # Refresh cached valence/H information leniently (strict=False) and run
        # ring perception before computing descriptors.
        mol.UpdatePropertyCache(strict=False)
        Chem.GetSSSR(mol)
        try:
            ri = mol.GetRingInfo()
            features = [
                Descriptors.MolWt(mol),
                mol.GetNumAtoms(),
                ri.NumRings(),
                Descriptors.MolLogP(mol),
                Descriptors.NumRadicalElectrons(mol),
                Descriptors.TPSA(mol),
                Descriptors.NumHeteroatoms(mol),
                Descriptors.NumHAcceptors(mol),
                Descriptors.NumHDonors(mol),
            ]
        except Exception:
            # All nine values are computed before any is added, so a failing
            # agent is skipped entirely instead of being half-counted.
            continue
        for k, value in enumerate(features):
            agent_feature_Fp[k] += value
    return agent_feature_Fp
def create_agent_morgan2_FP(rxn):
    """Sum the Morgan (radius 2) count fingerprints of all agents of a reaction.

    ``rxn.RemoveUnmappedReactantTemplates()`` is called first, so ``rxn`` is
    modified in place before its agents are read. Agents whose fingerprint
    cannot be built (or added to the running sum) are reported and skipped.
    When no agent fingerprint is obtained, an empty
    ``DataStructs.UIntSparseIntVect(2048)`` is returned.

    NOTE(review): AllChem.GetMorganFingerprint is called without a size, so the
    vectors it returns may not have length 2048 like the empty fallback -
    confirm before mixing the two or densifying them with hashedFPToNPfloat.
    """
    # Local import so the function works even if the cell that imports
    # DataStructs at module level has not been run yet.
    from rdkit import DataStructs

    rxn.RemoveUnmappedReactantTemplates()
    morgan2 = None
    for nra in range(rxn.GetNumAgentTemplates()):
        mol = rxn.GetAgentTemplate(nra)
        # Refresh cached valence/H information leniently and perceive rings.
        mol.UpdatePropertyCache(strict=False)
        Chem.GetSSSR(mol)
        try:
            mg2 = AllChem.GetMorganFingerprint(mol, radius=2)
            if mg2 is None:
                continue
            if morgan2 is None:
                morgan2 = mg2
            else:
                morgan2 += mg2
        except Exception:
            # Narrowed from a bare except so KeyboardInterrupt/SystemExit propagate.
            print("Cannot build agent Fp\n")
    if morgan2 is None:
        morgan2 = DataStructs.UIntSparseIntVect(2048)
    return morgan2
# Short names -> RDKit molecular fingerprint types; diff_fpgen looks up
# params.fptype here via its fp_type argument.
# NOTE: the "TopologicalFP" key selects the TopologicalTorsion type.
fptype_dict = dict(
    AtomPairFP=AllChem.FingerprintType.AtomPairFP,
    MorganFP=AllChem.FingerprintType.MorganFP,
    TopologicalFP=AllChem.FingerprintType.TopologicalTorsion,
    PatternFP=AllChem.FingerprintType.PatternFP,
    RDKitFP=AllChem.FingerprintType.RDKitFP,
)
def diff_fpgen(rxn, fptype_dict=fptype_dict, fp_type='MorganFP', include_agent=True, agent_weight=1, nonagent_weight=10):
    """Construct a difference fingerprint for a ChemicalReaction.

    The fingerprint is built by
    ``rdChemReactions.CreateDifferenceFingerprintForReaction``, which subtracts
    the reactant fingerprint from the product fingerprint.

    Args:
        rxn: RDKit ChemicalReaction.
        fptype_dict: mapping from short name to ``FingerprintType``
            (bound to the module-level ``fptype_dict`` at definition time).
        fp_type: key into ``fptype_dict`` selecting the molecular fingerprint.
        include_agent: whether agents contribute to the fingerprint.
        agent_weight: value for ``params.agentWeight``; only set when
            ``include_agent`` is true.
        nonagent_weight: value for ``params.nonAgentWeight``; only set when
            ``include_agent`` is true.

    Returns:
        The fingerprint returned by ``CreateDifferenceFingerprintForReaction``.

    Raises:
        KeyError: if ``fp_type`` is not a key of ``fptype_dict``.
    """
    params = rdChemReactions.ReactionFingerprintParams()
    params.fptype = fptype_dict[fp_type]
    params.includeAgents = include_agent
    if include_agent:
        # The agent / non-agent weighting only matters when agents are included.
        params.agentWeight = agent_weight
        params.nonAgentWeight = nonagent_weight
    return rdChemReactions.CreateDifferenceFingerprintForReaction(rxn, params)
from rdkit import DataStructs
def fingerprint2Numpy(FPs):
    """Copy an RDKit fingerprint into a new NumPy array and return it.

    The destination starts with a one-element placeholder shape;
    DataStructs.ConvertToNumpyArray is relied upon to resize it to the
    fingerprint length.
    """
    arr = np.zeros(1)
    DataStructs.ConvertToNumpyArray(FPs, arr)
    return arr
def hashedFPToNPfloat(fp, fpsz=2048):
    """Densify a hashed sparse count fingerprint into a float vector.

    Args:
        fp: object exposing ``GetNonzeroElements()`` -> ``{index: count}``
            (e.g. an RDKit SparseIntVect).
        fpsz: length of the output vector; indices must be below it,
            otherwise IndexError is raised.

    Returns:
        ``np.ndarray`` of shape ``(fpsz,)`` and dtype float holding the counts.
    """
    nonzero = fp.GetNonzeroElements()
    n_set = len(nonzero)
    positions = np.fromiter(nonzero.keys(), dtype=np.intp, count=n_set)
    counts = np.fromiter(nonzero.values(), dtype=float, count=n_set)
    dense = np.zeros((fpsz,), float)
    # Dict keys are unique, so a vectorised add equals the per-element loop.
    dense[positions] += counts
    return dense
df_rxn.sample(2)
%%time
# Difference fingerprint for every reaction using diff_fpgen's defaults (MorganFP).
# NOTE(review): diff_fpgen defaults to include_agent=True, so agents ARE weighted into
# this fingerprint despite the "_wo_agents" column name - confirm which was intended
# (pass include_agent=False to exclude them).
df_rxn['FP_Morgan_wo_agents'] = df_rxn['rxn_obj'].apply(diff_fpgen)
Including agents is causing problems right now; debug this eventually.
%%time
# Dense (n_reactions, 2048) float feature matrix from the sparse difference fingerprints
# (2048 is hashedFPToNPfloat's default fpsz).
X_FPs = np.array( [hashedFPToNPfloat(x) for x in df_rxn['FP_Morgan_wo_agents']] )
# Class labels as a NumPy array.
Y_class = np.array( df_rxn['Rxn Class'] )
# Sorted label list; this order is used for the evaluation table and heat-map axes below.
rtypes = sorted(list(reaction_types))
# Integer code formed by deleting the dots (assumes labels like '3.1.1' -> 311).
# NOTE(review): this is not injective in general ('1.12.1' and '11.2.1' both give 1121);
# the len(set(...)) display below shows whether collisions occur for this label set.
rtype_int = [int(''.join(entry.split('.'))) for entry in rtypes]
len(set(rtype_int))
Note on multi-class classification:
https://scikit-learn.org/stable/modules/multiclass.html#multiclass-classification
LabelBinarizer is not needed if you are using an estimator that already supports multiclass data.
https://scikit-learn.org/stable/modules/preprocessing_targets.html#preprocessing-targets
Optionally create a one-hot encoding (OHE) of the labels. It is unclear whether OHE helps here (RandomForestClassifier handles multiclass labels natively), but it is set up as a first pass.
# Choose the training target: the raw class labels (default) or a one-hot encoding.
leave_as_is = True
if leave_as_is:
    Y_target = Y_class
else:
    # One-hot encode the labels: one column per class, columns in sorted label order.
    # Y_class_OHE is built here; it was previously referenced without being defined.
    Y_class_OHE = pd.get_dummies(Y_class).to_numpy(dtype=float)
    Y_target = Y_class_OHE
    # NOTE: evaluateModel and the confusion-matrix code below expect 1-D labels,
    # so predictions would need argmax decoding if this branch is used.
from sklearn.model_selection import StratifiedShuffleSplit

# One stratified 50/50 train/test split, reproducible via random_state=42.
stratSplit = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
# With n_splits=1 the generator yields exactly one (train, test) index pair.
train_idx, test_idx = next(stratSplit.split(X_FPs, Y_target))
X_train, X_test = X_FPs[train_idx], X_FPs[test_idx]
Y_train, Y_test = Y_target[train_idx], Y_target[test_idx]
from sklearn.ensemble import RandomForestClassifier

# 250 trees, depth capped at 200, fixed seed. fit() returns the estimator
# itself, so the fitted model can be bound in one expression.
model = RandomForestClassifier(n_estimators=250, max_depth=200, random_state=42).fit(X_train, Y_train)
Y_test_predict = model.predict(X_test)
from sklearn.metrics import classification_report, confusion_matrix

# Per-class precision/recall/F1 as a nested dict (kept for inspection).
report_real = classification_report(Y_test, Y_test_predict, output_dict=True)
# Raw confusion matrix: rows = true class, columns = predicted class.
cmat_real = confusion_matrix(Y_test, Y_test_predict)
# Column totals: how many test reactions were predicted as each class.
cmat_real.sum(axis=0)
from sklearn import metrics
# evaluate model calculating recall, precision and F-score, return the confusion matrix
def evaluateModel(_model, _testFPs, _test_rxn_labels, _sorted_rxn_label, _names_rTypes):
    """Print per-class recall / precision / F-score and return the confusion matrix.

    Args:
        _model: fitted classifier with a ``predict`` method.
        _testFPs: feature matrix of the test set.
        _test_rxn_labels: true (1-D) labels of the test set.
        _sorted_rxn_label: class labels in the order used for the table rows and
            for the confusion-matrix rows/columns.
        _names_rTypes: mapping from class label to descriptive name.

    Returns:
        Confusion matrix whose row/column ``i`` corresponds to
        ``_sorted_rxn_label[i]`` (rows = true class, columns = predicted class).

    The final "Mean" line reports the macro-averaged recall and precision and
    the harmonic mean of those two averages (not the mean of per-class F-scores).
    """
    preds = _model.predict(_testFPs)
    # Pin the label order so cmat[i, i] really belongs to _sorted_rxn_label[i],
    # even if some class is absent from both the test labels and the predictions.
    cmat = metrics.confusion_matrix(_test_rxn_labels, preds, labels=_sorted_rxn_label)
    # Use NumPy axis sums: the builtin sum(cmat, 1) treats 1 as a *start value*
    # and yields column totals + 1 rather than row totals.
    colCounts = cmat.sum(axis=0)  # predicted-as-class totals (precision denominators)
    rowCounts = cmat.sum(axis=1)  # true-class totals (recall denominators)
    print('%2s %7s %7s %7s %s'%("ID","recall","prec","F-score ","reaction class"))
    sum_recall = 0
    sum_prec = 0
    for i, rxn_class_label in enumerate(_sorted_rxn_label):
        recall = 0
        if rowCounts[i] > 0:
            recall = float(cmat[i, i]) / rowCounts[i]
        sum_recall += recall
        prec = 0
        if colCounts[i] > 0:
            prec = float(cmat[i, i]) / colCounts[i]
        sum_prec += prec
        f_score = 0
        if (recall + prec) > 0:
            f_score = 2 * (recall * prec) / (recall + prec)
        print('%2d % .4f % .4f % .4f % 9s %s'%(i, recall, prec, f_score, rxn_class_label, _names_rTypes[rxn_class_label]))
    mean_recall = sum_recall / len(_sorted_rxn_label)
    mean_prec = sum_prec / len(_sorted_rxn_label)
    # Default to 0 so the summary prints even when both means are zero
    # (otherwise mean_fscore would be unbound -> NameError).
    mean_fscore = 0
    if (mean_recall + mean_prec) > 0:
        mean_fscore = 2 * (mean_recall * mean_prec) / (mean_recall + mean_prec)
    print("Mean:% 3.2f % 7.2f % 7.2f"%(mean_recall,mean_prec,mean_fscore))
    return cmat
# Evaluate the random forest on the held-out split, printing the per-class table.
# NOTE(review): despite the variable name, no agent features were appended to
# X_test - it holds only the densified difference fingerprints.
cmat_rFP_agentFeature = evaluateModel(
    _model=model,
    _testFPs=X_test,
    _test_rxn_labels=Y_test,
    _sorted_rxn_label=rtypes,
    _names_rTypes=names_rTypes,
)
def labelled_cmat(cmat, labels, figsize=(20,15), labelExtras=None, dpi=300, threshold=0.01, xlabel=True, ylabel=True, rotation=90):
    """Plot a row-normalised confusion matrix as a labelled heat map.

    Each row is divided by its total (the fraction of a true class predicted as
    each column class); fractions not above ``threshold`` are blanked to zero.

    Args:
        cmat: confusion matrix (rows = true class, columns = predicted class).
        labels: tick labels, one per row/column, in ``cmat`` order.
        figsize: matplotlib figure size.
        labelExtras: optional mapping label -> extra text appended to each tick.
        dpi: accepted for API compatibility; currently unused.
        threshold: cells with a row fraction <= threshold are drawn as zero.
        xlabel, ylabel: whether to draw the column / row tick labels.
        rotation: rotation (degrees) of the column tick labels.
    """
    cmat = np.asarray(cmat)
    # Row totals via NumPy: the builtin sum(cmat, 1) adds a start value of 1 to
    # the *column* totals, which normalised each row by the wrong count.
    rowCounts = cmat.sum(axis=1).astype(float)
    # Divide each row by its total; all-zero rows stay 0 instead of NaN (0/0).
    cmat_percent = np.divide(
        cmat,
        rowCounts[:, None],
        out=np.zeros(cmat.shape, dtype=float),
        where=rowCounts[:, None] > 0,
    )
    # zero all elements that are not above `threshold` (default 1%) of the row total
    ncm = cmat_percent * (cmat_percent > threshold)
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    pax = ax.pcolor(ncm, cmap=cm.ocean_r)
    ax.set_frame_on(True)
    # put the major ticks at the middle of each cell
    ax.set_yticks(np.arange(cmat.shape[0]) + 0.5, minor=False)
    ax.set_xticks(np.arange(cmat.shape[1]) + 0.5, minor=False)
    # want a more natural, table-like display
    ax.invert_yaxis()
    ax.xaxis.tick_top()
    if labelExtras is not None:
        labels = [' %s %s' % (x, labelExtras[x].strip()) for x in labels]
    ax.set_xticklabels([], minor=False)
    ax.set_yticklabels([], minor=False)
    if xlabel:
        ax.set_xticklabels(labels, minor=False, rotation=rotation, horizontalalignment='left')
    if ylabel:
        ax.set_yticklabels(labels, minor=False)
    ax.grid(True)
    fig.colorbar(pax)
    plt.axis('tight')
# Row totals via NumPy: the builtin sum(cmat, 1) adds a start value of 1 to the
# *column* totals, which normalised each row by the wrong count.
rowCounts = cmat_rFP_agentFeature.sum(axis=1).astype(float)
# Divide each row by its total; all-zero rows stay 0 instead of NaN (0/0).
cmat_percent = np.divide(
    cmat_rFP_agentFeature,
    rowCounts[:, None],
    out=np.zeros(cmat_rFP_agentFeature.shape, dtype=float),
    where=rowCounts[:, None] > 0,
)
# zero all elements that are not above 1% of the row total
ncm = cmat_percent * (cmat_percent > 0.01)
fig, ax = plt.subplots(1, 1, figsize=(20, 15))
pax = ax.pcolor(ncm, cmap=cm.ocean_r)
ax.set_frame_on(True)
# Tick text "<type label> <descriptive name>".
# NOTE(review): assumes the confusion-matrix rows/columns follow the `rtypes` order.
labels = [' %s %s' % (x, names_rTypes[x].strip()) for x in rtypes]
# put the major ticks at the middle of each cell
ax.set_yticks(np.arange(cmat_rFP_agentFeature.shape[0]) + 0.5, minor=False)
ax.set_xticks(np.arange(cmat_rFP_agentFeature.shape[1]) + 0.5, minor=False)
ax.set_xticklabels(labels, minor=False, rotation=90, horizontalalignment='left')
ax.set_yticklabels(labels, minor=False)
# want a more natural, table-like display
ax.invert_yaxis()
ax.xaxis.tick_top()
plt.show()