@@ -67,18 +67,17 @@ def Interpretation(data,
6767 print ("***SMILES_X Interpreter starts...***\n \n " )
6868 np .random .seed (seed = 123 )
6969 seed_list = np .random .randint (int (1e6 ), size = k_fold_number ).tolist ()
70- # Train/validation/test data splitting - 80/10/10 % at random with diff. seeds for k_fold_number times
71- selection_seed = seed_list [k_fold_index ]
7270
7371 print ("******" )
74- print ("***Fold #{} initiated...***" .format (selection_seed ))
72+ print ("***Fold #{} initiated...***" .format (k_fold_index ))
7573 print ("******" )
7674
7775 print ("***Sampling and splitting of the dataset.***\n " )
76+ # Reproducing the data split of the requested fold (k_fold_index)
7877 x_train , x_valid , x_test , y_train , y_valid , y_test , scaler = \
7978 utils .random_split (smiles_input = data .smiles ,
8079 prop_input = np .array (data .iloc [:,1 ]),
81- random_state = selection_seed ,
80+ random_state = seed_list [ k_fold_index ] ,
8281 scaling = True )
8382
8483 np .savetxt (save_dir + 'smiles_train.txt' , np .asarray (x_train ), newline = "\n " , fmt = '%s' )
@@ -145,7 +144,7 @@ def Interpretation(data,
145144 train_unique_tokens .insert (0 ,'pad' )
146145
147146 # Tokens as a list
148- tokens = token .get_vocab (input_dir + data_name + '_tokens_set_seed ' + str (selection_seed )+ '.txt' )
147+ tokens = token .get_vocab (input_dir + data_name + '_tokens_set_fold_ ' + str (k_fold_index )+ '.txt' )
149148 # Add 'pad', 'unk' tokens to the existing list
150149 tokens , vocab_size = token .add_extra_tokens (tokens , vocab_size )
151150
@@ -160,7 +159,7 @@ def Interpretation(data,
160159 int_to_token = token .get_inttotoken (tokens )
161160
162161 # Best architecture to visualize from
163- model_topredict = load_model (input_dir + 'LSTMAtt_' + data_name + '_model.best_seed_ ' + str (selection_seed )+ '.hdf5' ,
162+ model_topredict = load_model (input_dir + 'LSTMAtt_' + data_name + '_model.best_fold_ ' + str (k_fold_index )+ '.hdf5' ,
164163 custom_objects = {'AttentionM' : model .AttentionM ()})
165164 best_arch = [model_topredict .layers [2 ].output_shape [- 1 ]/ 2 ,
166165 model_topredict .layers [3 ].output_shape [- 1 ],
@@ -179,7 +178,7 @@ def Interpretation(data,
179178 print ("\n " )
180179
181180 print ("***Interpretation from the best model.***\n " )
182- model_att .load_weights (input_dir + 'LSTMAtt_' + data_name + '_model.best_seed_ ' + str (selection_seed )+ '.hdf5' )
181+ model_att .load_weights (input_dir + 'LSTMAtt_' + data_name + '_model.best_fold_ ' + str (k_fold_index )+ '.hdf5' )
183182 model_att .compile (loss = "mse" , optimizer = 'adam' , metrics = [metrics .mae ,metrics .mse ])
184183
185184 smiles_toviz_x_enum_tokens_tointvec = token .int_vec_encode (tokenized_smiles_list = smiles_toviz_x_enum_tokens ,
@@ -210,7 +209,7 @@ def Interpretation(data,
210209 fontsize = font_size ,
211210 rotation = font_rotation )
212211 plt .yticks ([])
213- plt .savefig (save_dir + 'Interpretation_1D_' + data_name + '_seed_ ' + str (selection_seed )+ '.png' , bbox_inches = 'tight' )
212+ plt .savefig (save_dir + 'Interpretation_1D_' + data_name + '_fold_ ' + str (k_fold_index )+ '.png' , bbox_inches = 'tight' )
214213 #plt.show()
215214
216215 smiles_tmp = smiles_toviz_x_enum [ienumcard ]
@@ -233,7 +232,7 @@ def Interpretation(data,
233232 colorMap = 'Reds' ,
234233 contourLines = 10 ,
235234 alpha = 0.25 )
236- fig .savefig (save_dir + 'Interpretation_2D_' + data_name + '_seed_ ' + str (selection_seed )+ '.png' , bbox_inches = 'tight' )
235+ fig .savefig (save_dir + 'Interpretation_2D_' + data_name + '_fold_ ' + str (k_fold_index )+ '.png' , bbox_inches = 'tight' )
237236 #fig.show()
238237
239238 model_topredict .compile (loss = "mse" , optimizer = 'adam' , metrics = [metrics .mae ,metrics .mse ])
@@ -276,7 +275,7 @@ def Interpretation(data,
276275 rotation = font_rotation )
277276 plt .yticks (fontsize = 20 )
278277 plt .ylabel ('Temporal relative distance' , fontsize = 25 , labelpad = 15 )
279- plt .savefig (save_dir + 'Interpretation_temporal_' + data_name + '_seed_ ' + str (selection_seed )+ '.png' , bbox_inches = 'tight' )
278+ plt .savefig (save_dir + 'Interpretation_temporal_' + data_name + '_fold_ ' + str (k_fold_index )+ '.png' , bbox_inches = 'tight' )
280279 #plt.show()
281280##
282281
0 commit comments