Neural Network cross validation

Anitha on 13 Mar 2014
Commented: Yogini Prabhu on 24 May 2021
I am new to MATLAB. I have implemented a character recognition system using neural networks. Now I am trying to implement a 10-fold cross-validation scheme for the neural network. I have written the following code, but I don't know if it is correct. Please help me.
close all
clear all
load inputdata
load targetdata
inputs = input;
targets = target;
% Create a Pattern Recognition Network
hiddenLayerSize = 30;
net = patternnet(hiddenLayerSize);
% Choose Input and Output Pre/Post-Processing Functions
% For a list of all processing functions type: help nnprocess
net.inputs{1}.processFcns = {'removeconstantrows','mapminmax'};
net.outputs{2}.processFcns = {'removeconstantrows','mapminmax'};
k=10;
groups=[1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 4 4 4 4 4 4 4 4 4 4 5 5 5 5 5 5 5 5 5 5 6 6 6 6 6 6 6 6 6 6 7 7 7 7 7 7 7 7 7 7 8 8 8 8 8 8 8 8 8 8 9 9 9 9 9 9 9 9 9 9 10 10 10 10 10 10 10 10 10 10 11 11 11 11 11 11 11 11 11 11 12 12 12 12 12 12 12 12 12 12 13 13 13 13 13 13 13 13 13 13 1 2 3 4 5 6 7 8 9 10 11 12 13 1 2 3 4 5 6 7 8 9 10 11 12 13 1 2 3 4 5 6 7 8 9 10 11 12 13 1 2 3 4 5 6 7 8 9 10 11 12 13 1 2 3 4 5 6 7 8 9 10 11 12 13 1 2 3 4 5 6 7 8 9 10 11 12 13 1 2 3 4 5 6 7 8 9 10 11 12 13 1 2 3 4 5 6 7 8 9 10 11 12 13 1 2 3 4 5 6 7 8 9 10 11 12 13 1 2 3 4 5 6 7 8 9 10 11 12 13 1 2 3 4 5 6 7 8 9 10 11 12 13 1 2 3 4 5 6 7 8 9 10 11 12 13]; %target
cvFolds = crossvalind('Kfold', groups, k); %# get indices of 10-fold CV
for i = 1:k                          %# for each fold
    testIdx = (cvFolds == i);        %# get indices of test instances
    trainIdx = ~testIdx;             %# get indices of training instances
    trInd = find(trainIdx);
    tstInd = find(testIdx);
    net.trainFcn = 'trainbr';
    net.trainParam.epochs = 100;
    net.divideFcn = 'divideind';
    net.divideParam.trainInd = trInd;
    net.divideParam.testInd = tstInd;
    % Choose a Performance Function
    net.performFcn = 'mse';  % Mean squared error
    % Train the Network
    [net,tr] = train(net,inputs,targets);
    %# test using test instances
    outputs = net(inputs);
    errors = gsubtract(targets,outputs);
    performance = perform(net,targets,outputs)
    trainTargets = targets .* tr.trainMask{1};
    testTargets = targets .* tr.testMask{1};
    trainPerformance = perform(net,trainTargets,outputs)
    testPerformance = perform(net,testTargets,outputs)
    test(i) = testPerformance;       % store the test performance of fold i
    save net
    figure, plotconfusion(targets,outputs)
end
accuracy=mean(test);
% View the Network
view(net)
Greg Heath on 13 Mar 2014
When I cut and paste your code into the command line it does not run because it is not properly formatted.
I suggest that you reformat your post so that it will run when cut and pasted.
However, I do not have crossvalind or crossperf, so I'm not sure how much help I can be.
Reformat and I will try to do what I can.
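For readers who, like Greg, do not have crossvalind (it ships with the Bioinformatics Toolbox), the fold assignment can be built in base MATLAB. The following is a minimal sketch, not a drop-in replacement for crossvalind: the labels vector is a hypothetical stand-in for the question's groups vector, and dealing each class's shuffled samples round-robin over the folds is an illustrative choice meant to keep the class mix of every fold similar.

```matlab
% Sketch: class-balanced k-fold assignment using only base MATLAB.
% 'labels' is hypothetical example data (3 classes, 10 samples each);
% in the question it would be the 'groups' vector.
labels = repelem(1:3, 10);
k = 10;
rng(0);                               % seed the RNG so the split can be repeated

folds = zeros(numel(labels), 1);      % fold number (1..k) for each sample
classes = unique(labels);
offset = 0;                           % carries the round-robin position across classes
for c = 1:numel(classes)
    idx = find(labels == classes(c));              % samples of this class
    idx = idx(randperm(numel(idx)));               % shuffle them
    folds(idx) = mod(offset + (0:numel(idx)-1), k) + 1;  % deal them to folds in turn
    offset = offset + numel(idx);
end

for i = 1:k
    testIdx  = (folds == i);          % logical mask of test samples for fold i
    trainIdx = ~testIdx;              % all remaining samples are used for training
    fprintf('fold %2d: %d test, %d train\n', i, nnz(testIdx), nnz(trainIdx));
end
```

With 30 samples and k = 10, every fold receives three test samples, one from each class. The folds vector plays the same role as cvFolds in the question's loop.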

Sign in to comment.

Accepted Answer

Greg Heath on 14 Mar 2014
Edited: Greg Heath on 14 Mar 2014
The formatting is not perfect. Did you cut and paste this version?
Results?
size(inputs) = ?
size(targets) = ?
Did you try to minimize the hidden layer size?
Take net.divideFcn and net.trainFcn out of the loop.
trainbr uses Bayesian regularization (it minimizes a combination of squared errors and weights), not ordinary mse.
Did you try the default values of the remaining net.* settings before overwriting them?
Initialize the RNG just before the loop so you can repeat your run if needed.
You have to configure the net at the top of the loop; otherwise, weight initialization will only occur for the first net.
The train and test performances are already in tr. There is no need to recalculate them.
You may have to use the masks on both targets and outputs. Check to make sure.
If you save each net, they must have different names. A 10-element cell array should work.
Why not save the training performance as well?
Then calculate the min, median, mean, std and max of both the train and test performances.
Why not run the iris_dataset example before trying your own data?
If you use crossval and cvpartition, we could compare results. However, although I have them, I have never used them. It might just be easier if you used my crossval code in the Newsgroup.
Also, index your confusion plots; otherwise each one will overwrite the previous one.
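Several of the points above can be combined into one loop. The following is a minimal sketch, assuming the Deep Learning Toolbox (formerly Neural Network Toolbox) and its iris_dataset example data. The hidden layer size of 10, the trainscg training function, and the simple non-stratified fold assignment are illustrative choices, not values taken from this thread. In recent releases patternnet's default performance function is crossentropy, so the values read from tr are crossentropy rather than mse.

```matlab
% Sketch only: assumes the Deep Learning Toolbox and its iris_dataset example.
[x, t] = iris_dataset;                 % 4x150 inputs, 3x150 one-hot targets
N = size(x, 2);
k = 10;

net = patternnet(10);                  % hidden layer size 10 is an arbitrary choice
net.trainFcn  = 'trainscg';            % chosen once, outside the loop
net.divideFcn = 'divideind';           % chosen once, outside the loop
net.trainParam.showWindow = false;     % no training GUI for each fold

rng(0);                                % seed once, before the loop, so runs repeat
folds = mod(randperm(N), k) + 1;       % random fold 1..k per sample (not stratified)

nets      = cell(1, k);                % a separate trained net for each fold
trainPerf = zeros(1, k);
testPerf  = zeros(1, k);

for i = 1:k
    net.divideParam.trainInd = find(folds ~= i);
    net.divideParam.valInd   = [];
    net.divideParam.testInd  = find(folds == i);
    net = configure(net, x, t);        % set input/output sizes from the data
    net = init(net);                   % fresh random initial weights for this fold
    [nets{i}, tr] = train(net, x, t);  % 'net' itself stays untrained
    trainPerf(i) = tr.best_perf;       % performances are already stored in tr
    testPerf(i)  = tr.best_tperf;
end

stats = @(p) [min(p), median(p), mean(p), std(p), max(p)];
fprintf('        min     median  mean    std     max\n');
fprintf('train:  %.4f  %.4f  %.4f  %.4f  %.4f\n', stats(trainPerf));
fprintf('test:   %.4f  %.4f  %.4f  %.4f  %.4f\n', stats(testPerf));
```

Because the trained networks are stored in the cell array nets, each fold's net survives the loop and can be saved under one name, and the untrained net is reused as the starting template for every fold.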
Hope this helps.
Thank you for formally accepting my answer
Greg
