Classification Experiment Design - stopping training

Hi,
I have a question (or several) about experiment design for binary classification in MATLAB. I have a "design" set (for training and validation) and a separate test set to evaluate generalisation. My problem is deciding when to stop training and apply the resulting net to the test set. From the nnet FAQ: "Statisticians tend to be skeptical of stopped training because it appears to be statistically inefficient due to the use of the split-sample technique." So I use cross-validation with the following general method:
  • for a given number of hidden nodes (H), I create 100 random initial weight sets (S)
  • for each S, I randomly divide the design set into K equally sized, mutually exclusive subsets and train K nets, using subset K(i) as the validation set and the remaining K-1 subsets as the training set
  • each net is trained with the stopping goal mse_goal = 1e-6
  • I evaluate the validation error for each K(i) and retrain the relevant net to the number of epochs at which the validation error was lowest
  • (do I need to do this, or can I somehow select/return the net with the weights at that epoch from the [net tr] output?)
  • I apply this net to my test set to evaluate generalisation
  • the net from the S*K set with the lowest generalisation error is the best trained net for the given H, using my available data
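A minimal sketch of the procedure above, written in Python rather than MATLAB so it runs without the toolbox. A single logistic unit trained by batch gradient descent stands in for the net, it trains for a fixed number of epochs instead of to mse_goal, and the toy data, learning rate and epoch count are all hypothetical. It illustrates one way around the retraining question: if a copy of the weights is kept whenever the validation MSE reaches a new minimum, the net at the best epoch is available without retraining.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def mse(data, w, b):
    """Mean squared error of the logistic unit on (x, t) pairs."""
    return sum((sigmoid(w * x + b) - t) ** 2 for x, t in data) / len(data)

def kfold(n, k, rng):
    """Randomly split indices 0..n-1 into k mutually exclusive, near-equal folds."""
    idx = list(range(n))
    rng.shuffle(idx)
    return [idx[i::k] for i in range(k)]

def train_keep_best(trn, val, seed, epochs=300, lr=2.0):
    """Batch gradient descent on one logistic unit (a stand-in for a net).
    Keeps a copy of the weights at the epoch of minimum validation MSE,
    so that net never has to be retrained to reach that epoch."""
    rng = random.Random(seed)                  # seed -> reproducible initial weights
    w, b = rng.uniform(-1, 1), rng.uniform(-1, 1)
    best = (mse(val, w, b), 0, (w, b))         # (val MSE, epoch, weights)
    for epoch in range(1, epochs + 1):
        gw = gb = 0.0
        for x, t in trn:
            y = sigmoid(w * x + b)
            d = (y - t) * y * (1 - y)          # d/dz of 0.5*(y - t)^2
            gw += d * x
            gb += d
        w -= lr * gw / len(trn)
        b -= lr * gb / len(trn)
        v = mse(val, w, b)
        if v < best[0]:
            best = (v, epoch, (w, b))          # snapshot the best-so-far weights
    return best

# Hypothetical toy design set: class 1 when x > 0.3.
data_rng = random.Random(42)
xs = [data_rng.uniform(-2, 2) for _ in range(60)]
design = [(x, 1.0 if x > 0.3 else 0.0) for x in xs]

K, S = 5, 3                                    # K folds, S random initial weight sets
results = []                                   # one entry per (s, i) net
for s in range(S):
    folds = kfold(len(design), K, random.Random(1000 + s))  # new random split per S
    for i in range(K):
        val = [design[j] for j in folds[i]]
        trn = [design[j] for m, f in enumerate(folds) if m != i for j in f]
        v, epoch, weights = train_keep_best(trn, val, seed=s)
        results.append((s, i, epoch, v, weights))

for s, i, epoch, v, _ in results[:3]:
    print(f"S={s} fold={i} best epoch={epoch} val MSE={v:.4f}")
```

Each entry of results already holds the weights at its best validation epoch, so those nets can be applied to the test set directly.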
Does this make sense?
Greg Heath on 20 Apr 2013
You failed to give four important values:
1. N: size of the data set
2. I: input dimensionality
3. O: output dimensionality
4. MSE00: mean target variance, mean(var(target'))
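For reference, a small Python sketch of how these four values can be read off the data, assuming the column-per-sample layout (I x N inputs, O x N targets) that MATLAB's neural network functions use. The example numbers are hypothetical. Up to the N versus N-1 normalization, MSE00 is the MSE of a model that always outputs the mean target, which makes it a reference level for the trained net's MSE.

```python
from statistics import mean, variance

def data_dims(inputs, targets):
    """inputs is I x N and targets is O x N: one row per variable and one
    column per sample, the layout MATLAB's neural network functions expect."""
    I, O = len(inputs), len(targets)
    N = len(targets[0])
    if any(len(row) != N for row in inputs + targets):
        raise ValueError("every row must have N samples")
    return N, I, O

def mse00(targets):
    """Mean target variance, the analogue of MATLAB mean(var(target')).
    var(target') gives one variance per target row (output variable), and
    statistics.variance uses the same N-1 normalization as var's default."""
    return mean(variance(row) for row in targets)

# Hypothetical binary-classification data: I = 2 inputs, O = 1 output, N = 4.
inputs = [[0.1, 0.4, 0.9, 0.3],
          [1.2, 0.8, 0.5, 1.1]]
targets = [[0, 1, 1, 0]]
print(data_dims(inputs, targets))   # (4, 2, 1)
print(mse00(targets))               # about 0.333 (1/3)
```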


Accepted Answer

Greg Heath on 20 Apr 2013
1. Initialize the RNG before the H loop and record the current RNG seed at the beginning of each inner loop. Any individual net can then be retrained from its corresponding values of H, ntrial and SEED.
2. Train on k-2 subsets, using MSEtrngoal = max(0, 0.01*Ndof*MSE00a/Ntrneq).
3. Record MSEtrn and MSEtrna for the combined training set and BOTH MSEs of the nontraining subsets at the MSE minima of EACH nontraining set:
a. MSEk-1 at the minimum of MSEk is an unbiased estimate of generalization error.
b. So is MSEk at the minimum of MSEk-1.
c. So is their average, MSEtst = (MSEk + MSEk-1)/2.
4. Calculate the summary stats (min, median, mean, stdv, max) of R2trn, R2trna and R2tst, and plot them as a function of H.
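The steps above can be sketched in Python as follows. The thread does not define Ntrneq, Nw or Ndof, so the definitions used here (Ntrneq = Ntrn*O training equations, Nw = (I+1)*H + (H+1)*O weights for a single-hidden-layer I-H-O net, Ndof = Ntrneq - Nw) are assumptions, as is R2 = 1 - MSE/MSE00; R2trna is not reproduced. Saving the RNG state with getstate() plays the role of recording the seed in step 1, similar in spirit to saving s = rng in MATLAB. The MSE curves and sizes in the calls are hypothetical.

```python
import random
from statistics import mean, median, stdev

def dof(Ntrn, I, H, O):
    # ASSUMPTION: conventions for a single-hidden-layer I-H-O net, not defined in the thread.
    Ntrneq = Ntrn * O                      # number of training equations
    Nw = (I + 1) * H + (H + 1) * O         # weights plus biases
    return Ntrneq, Nw, Ntrneq - Nw         # Ndof = Ntrneq - Nw

def mse_trn_goal(Ntrn, I, H, O, MSE00a):
    """Step 2: MSEtrngoal = max(0, 0.01*Ndof*MSE00a/Ntrneq)."""
    Ntrneq, _, Ndof = dof(Ntrn, I, H, O)
    return max(0.0, 0.01 * Ndof * MSE00a / Ntrneq)

def cross_estimates(mse_k, mse_km1):
    """Step 3: from per-epoch MSE curves of the two nontraining subsets, return
    (MSEk-1 at the minimum of MSEk, MSEk at the minimum of MSEk-1, their average)."""
    ek = min(range(len(mse_k)), key=mse_k.__getitem__)        # epoch of min MSEk
    ekm1 = min(range(len(mse_km1)), key=mse_km1.__getitem__)  # epoch of min MSEk-1
    a, b = mse_km1[ek], mse_k[ekm1]
    return a, b, (a + b) / 2

def r2(mse, mse00):
    # ASSUMPTION: R2 = 1 - MSE/MSE00, the fraction of target variance explained
    return 1 - mse / mse00

def summary(values):
    """Step 4: (min, median, mean, stdv, max), e.g. of R2tst over trials at one H."""
    return min(values), median(values), mean(values), stdev(values), max(values)

# Step 1: initialize the RNG once, before the H loop, and record its state at the
# start of each inner loop so that any single net can be re-created later.
rng = random.Random(0)
saved, inits = {}, {}
for H in (1, 2, 3):
    for ntrial in range(4):
        saved[(H, ntrial)] = rng.getstate()
        # Stand-in for weight initialization and data division for this net.
        inits[(H, ntrial)] = [rng.uniform(-1, 1) for _ in range(3)]

def replay_init(H, ntrial):
    """Restore the recorded state and repeat the same random draws."""
    r = random.Random()
    r.setstate(saved[(H, ntrial)])
    return [r.uniform(-1, 1) for _ in range(3)]

print(mse_trn_goal(100, 2, 3, 1, 0.25))
print(cross_estimates([0.5, 0.3, 0.4], [0.6, 0.5, 0.2]))
print(replay_init(2, 1) == inits[(2, 1)])
```

The max(0, ...) in step 2 matters when H is large: under the assumed definitions, with Ntrn = 10, I = 5, H = 10 and O = 1 there are more weights than training equations, Ndof is negative, and the goal is clamped to 0.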
Hope this helps.
Thank you for formally accepting my answer.
Greg
