The built-in weight update algorithms and the function needed to register new ones are located in train.c. You should start by defining a new weight update procedure. This is a function called after each batch of examples has been processed. Its job is to update the weight on each link based on its accumulated error derivatives.
To begin, copy the function steepestUpdateWeights(). You will probably want to put your new function in extension.c. The function takes one argument, the boolean doStats. If doStats is true, a report is about to be produced, and the update function should calculate the network's gradientLinearity and weightCost while it updates the weights. If no report will be produced, these calculations would be a waste of time.
The weight update procedure is further complicated by checks for FROZEN structures. Structures can be frozen at any level, from the whole network down to the individual block. To save time, the units of a frozen group are not traversed during the weight update, so frozen links do not contribute to the weightCost or gradientLinearity.
Because traversing the weights in the network is relatively complicated, the UPDATE_WEIGHTS macro was written. This takes one argument, a block of code that is run on each link. UPDATE_WEIGHTS provides a number of variables that you can use within this code, including learningRate, momentum, and weightDecay. L will be a pointer to the Link structure for the link and M will be a pointer to its Link2 structure. Some variables can also be written and read as necessary, including lastWeightDelta and deriv.
The UPDATE_WEIGHTS macro takes care of calculating the statistics needed when doStats is true.
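To make the shape of such a procedure concrete, here is a minimal, self-contained sketch. It does not use the real UPDATE_WEIGHTS macro or the real Lens structures: SketchLink, SketchLink2, SketchNet, and SKETCH_UPDATE_WEIGHTS are hypothetical stand-ins that only imitate the behavior described above (a block of code run on each unfrozen link, with L, M, learningRate, momentum, and weightDecay in scope). The update rule is ordinary momentum descent chosen for illustration, and the statistic is a placeholder rather than Lens's actual weightCost.

```c
#include <stddef.h>

/* Hypothetical stand-ins: the real Link and Link2 structures in Lens
   have more fields and are laid out differently. */
typedef struct { float weight; } SketchLink;
typedef struct { float deriv; float lastWeightDelta; } SketchLink2;

typedef struct {
  SketchLink  *links;       /* one entry per link */
  SketchLink2 *links2;      /* parallel array of per-link training data */
  size_t       numLinks;
  float        learningRate, momentum, weightDecay;
  int          frozen;      /* nonzero: skip every link */
  float        weightCost;  /* placeholder statistic: sum of squared weights */
} SketchNet;

/* Simplified imitation of a traversal macro. It runs `code` once per
   unfrozen link with L, M, learningRate, momentum, and weightDecay in
   scope, and gathers the statistic only when doStats is nonzero.
   `code` must not contain top-level commas. */
#define SKETCH_UPDATE_WEIGHTS(net, doStats, code) do {              \
    float learningRate = (net)->learningRate;                       \
    float momentum     = (net)->momentum;                           \
    float weightDecay  = (net)->weightDecay;                        \
    size_t i_;                                                      \
    if (doStats) (net)->weightCost = 0.0f;                          \
    for (i_ = 0; !(net)->frozen && i_ < (net)->numLinks; i_++) {    \
      SketchLink  *L = &(net)->links[i_];                           \
      SketchLink2 *M = &(net)->links2[i_];                          \
      code                                                          \
      if (doStats) (net)->weightCost += L->weight * L->weight;      \
    }                                                               \
  } while (0)

/* Momentum descent with weight decay, written as a per-link block.
   Unlike a real Lens update function, which takes only doStats, this
   sketch also takes the network because it has no global state. */
void sketchMomentumUpdate(SketchNet *net, int doStats) {
  SKETCH_UPDATE_WEIGHTS(net, doStats,
    M->lastWeightDelta = momentum * M->lastWeightDelta
                       - learningRate * (M->deriv + weightDecay * L->weight);
    L->weight += M->lastWeightDelta;
  );
}
```

The point of the design is that the per-link rule stays two lines long, while the traversal, the frozen check, and the optional statistics live in one place.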
If your update function doesn't need any link fields other than the weight, deriv, and lastWeightDelta (the last of which is used for the gradient linearity and for reference on the graphical displays), you can compile your algorithm in the standard version of Lens. However, if you need an extra link field, you can use the lastValue field, but you must then allow your function to be compiled only in the ADVANCED version of Lens. To do this, surround all code that refers to the algorithm with the following:
#ifdef ADVANCED
...
#endif /* ADVANCED */
These sections will be ignored if Lens is compiled without the ADVANCED macro defined.
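As an illustration of what "all code which refers to the algorithm" can mean in practice, the following sketch guards both a function that touches lastValue and the place where the algorithm is counted. SketchLink2, sketchRememberDeriv(), and sketchCountAlgorithms() are hypothetical names invented for this example, and the sketch assumes, as the text above implies, that lastValue exists only in ADVANCED builds.

```c
/* Hypothetical stand-in for a link's training data. This sketch
   assumes the extra lastValue field is present only when ADVANCED
   is defined. */
typedef struct {
  float deriv;
  float lastWeightDelta;
#ifdef ADVANCED
  float lastValue;
#endif /* ADVANCED */
} SketchLink2;

#ifdef ADVANCED
/* Uses lastValue, so the whole definition must be guarded. */
void sketchRememberDeriv(SketchLink2 *M) {
  M->lastValue = M->deriv;
}
#endif /* ADVANCED */

/* Every other reference to the guarded algorithm is wrapped as well,
   so the file still compiles when ADVANCED is not defined. */
int sketchCountAlgorithms(void) {
  int n = 1;  /* one algorithm that works in the standard version */
#ifdef ADVANCED
  n++;        /* plus the one that needs lastValue */
#endif /* ADVANCED */
  return n;
}
```

If only the function definition were guarded, a standard build would fail wherever the function is still mentioned, which is why every reference needs the same guard.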
Next you will need to choose a bit mask for your algorithm. Here are the predefined masks:
#define STEEPEST ((mask) 1 << 0)
#define MOMENTUM ((mask) 1 << 1)
#define DOUGS_MOMENTUM ((mask) 1 << 2)
#define DELTA_BAR_DELTA ((mask) 1 << 3)
You should put your new definition in type.h. It is a good idea to skip a few bits after the last predefined one to leave room for new additions. You may as well start with ((mask) 1 << 10).
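A new mask is only useful if it is a single bit that none of the predefined algorithms already use. The sketch below shows one way to check this. MY_ALGORITHM and sketchMaskIsUsable() are hypothetical names, and the typedef of mask as unsigned int is an assumption made so the example stands alone; the real definition lives in the Lens headers.

```c
typedef unsigned int mask;  /* assumption: stands in for the Lens mask type */

/* The predefined masks, copied from the list above. */
#define STEEPEST        ((mask) 1 << 0)
#define MOMENTUM        ((mask) 1 << 1)
#define DOUGS_MOMENTUM  ((mask) 1 << 2)
#define DELTA_BAR_DELTA ((mask) 1 << 3)

/* Hypothetical new algorithm, starting at bit 10 as suggested. */
#define MY_ALGORITHM    ((mask) 1 << 10)

/* Returns nonzero if m has exactly one bit set and that bit is not
   taken by any of the predefined masks. */
int sketchMaskIsUsable(mask m) {
  mask predefined = STEEPEST | MOMENTUM | DOUGS_MOMENTUM | DELTA_BAR_DELTA;
  return m != 0 && (m & (m - 1)) == 0 && (m & predefined) == 0;
}
```

The expression m & (m - 1) clears the lowest set bit, so it is zero exactly when m has a single bit set.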
The last step in creating an algorithm is to register it. To do this, you should place a call to registerAlgorithm() in userInit() in extension.c. Here is the call used to register the steepest descent algorithm:
registerAlgorithm(STEEPEST, "steepest", "Steepest Descent", steepestUpdateWeights);
The first argument is the mask you defined in type.h. The second argument is the short name. This will be the name of the shell command that is created to train using your algorithm. It will also be the switch that can be given to the train command. The third argument is the long name. This is the name that will appear in messages and on the main display (a radio button for your algorithm will be created). The last argument is the weight update function you created.
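The four arguments can be pictured as one row in a table of algorithms. The sketch below is not the Lens implementation of registerAlgorithm(): sketchRegisterAlgorithm(), sketchFindAlgorithm(), sketchUserInit(), MY_ALGORITHM, and myUpdateWeights() are hypothetical names, and the mask and flag typedefs are assumptions. It only mirrors the argument order shown above, under the assumption that masks and short names should be unique.

```c
#include <stddef.h>
#include <string.h>

typedef unsigned int mask;                     /* assumed mask type */
typedef int flag;                              /* assumed boolean type */
typedef void (*SketchUpdateFn)(flag doStats);  /* weight update procedure */

typedef struct {
  mask           code;       /* bit mask identifying the algorithm */
  const char    *shortName;  /* command and switch name */
  const char    *longName;   /* name shown in messages and displays */
  SketchUpdateFn update;     /* called after each batch */
} SketchAlgorithm;

#define SKETCH_MAX_ALGORITHMS 32
static SketchAlgorithm sketchTable[SKETCH_MAX_ALGORITHMS];
static size_t sketchNumAlgorithms = 0;

/* Stores one algorithm. Returns 1 on success, or 0 if the table is
   full or the mask or short name is already taken. */
int sketchRegisterAlgorithm(mask code, const char *shortName,
                            const char *longName, SketchUpdateFn update) {
  size_t i;
  if (sketchNumAlgorithms >= SKETCH_MAX_ALGORITHMS) return 0;
  for (i = 0; i < sketchNumAlgorithms; i++)
    if (sketchTable[i].code == code ||
        strcmp(sketchTable[i].shortName, shortName) == 0)
      return 0;
  sketchTable[sketchNumAlgorithms].code = code;
  sketchTable[sketchNumAlgorithms].shortName = shortName;
  sketchTable[sketchNumAlgorithms].longName = longName;
  sketchTable[sketchNumAlgorithms].update = update;
  sketchNumAlgorithms++;
  return 1;
}

/* Looks an algorithm up by short name; returns NULL if not found. */
const SketchAlgorithm *sketchFindAlgorithm(const char *shortName) {
  size_t i;
  for (i = 0; i < sketchNumAlgorithms; i++)
    if (strcmp(sketchTable[i].shortName, shortName) == 0)
      return &sketchTable[i];
  return NULL;
}

/* Hypothetical new algorithm and its registration. */
#define MY_ALGORITHM ((mask) 1 << 10)
static int myUpdateCalls = 0;
void myUpdateWeights(flag doStats) { (void) doStats; myUpdateCalls++; }

int sketchUserInit(void) {
  return sketchRegisterAlgorithm(MY_ALGORITHM, "myAlgorithm",
                                 "My Algorithm", myUpdateWeights);
}
```

Once an entry is stored this way, a training command could look the algorithm up by its short name and call its update function through the stored pointer.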