00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef __itkLayerBase_h
00018 #define __itkLayerBase_h
00019
00020 #include <iostream>
00021 #include "itkLightProcessObject.h"
00022 #include "itkWeightSetBase.h"
00023 #include "itkArray.h"
00024 #include "itkVector.h"
00025 #include "itkTransferFunctionBase.h"
00026 #include "itkInputFunctionBase.h"
00027
00028 #include "itkMacro.h"
00029
00030 namespace itk
00031 {
00032 namespace Statistics
00033 {
00034
00035 template<class TVector, class TOutput>
00036 class LayerBase : public LightProcessObject
00037 {
00038
00039 public:
00040 typedef LayerBase Self;
00041 typedef LightProcessObject Superclass;
00042 typedef SmartPointer<Self> Pointer;
00043 typedef SmartPointer<const Self> ConstPointer;
00044
00046 itkTypeMacro(LayerBase, LightProcessObject);
00047
00048 typedef TVector InputVectorType;
00049 typedef TOutput OutputVectorType;
00050
00051 typedef typename TVector::ValueType ValueType;
00052 typedef ValueType* ValuePointer;
00053 typedef vnl_vector<ValueType> NodeVectorType;
00054 typedef Array<ValueType> InternalVectorType;
00055
00056 typedef WeightSetBase<TVector,TOutput> WeightSetType;
00057
00058 typedef TransferFunctionBase<ValueType> TransferFunctionType;
00059
00060 typedef InputFunctionBase<ValueType*, ValueType> InputFunctionType;
00061
00062 typedef typename InputFunctionType::Pointer InputFunctionPointer;
00063
00064 typedef typename TransferFunctionType::Pointer TransferFunctionPointer;
00065
00066 typedef typename WeightSetType::Pointer WeightSetPointer;
00067
00068 virtual void SetNumberOfNodes(unsigned int);
00069 unsigned int GetNumberOfNodes();
00070
00071 virtual ValueType GetInputValue(unsigned int) = 0;
00072 virtual ValueType GetOutputValue(int) = 0;
00073 virtual ValuePointer GetOutputVector() = 0;
00074
00075 virtual void ForwardPropagate(){};
00076
00077 virtual void ForwardPropagate(TVector){};
00078
00079 virtual void BackwardPropagate(InternalVectorType){};
00080
00081 virtual void BackwardPropagate(){};
00082 virtual ValueType GetOutputErrorValue(unsigned int) = 0;
00083 virtual void SetOutputErrorValues(TOutput) {};
00084
00085 virtual ValueType GetInputErrorValue(int) = 0;
00086 virtual ValuePointer GetInputErrorVector() = 0;
00087 virtual void SetInputErrorValue(ValueType, int) {};
00088
00089
00090 void SetInputWeightSet(WeightSetType*);
00091 itkGetObjectMacro(InputWeightSet, WeightSetType);
00092
00093
00094 void SetOutputWeightSet(WeightSetType*);
00095 itkGetObjectMacro(OutputWeightSet, WeightSetType);
00096
00097 void SetNodeInputFunction(InputFunctionType* f);
00098 itkGetObjectMacro(NodeInputFunction, InputFunctionType);
00099
00100 void SetTransferFunction(TransferFunctionType* f);
00101 itkGetObjectMacro(ActivationFunction, TransferFunctionType);
00102
00103 virtual ValueType Activation(ValueType) = 0;
00104 virtual ValueType DActivation(ValueType) = 0;
00105
00106 itkSetMacro(LayerType, unsigned int);
00107 itkGetMacro(LayerType, unsigned int);
00108
00109 itkSetMacro(LayerId,int);
00110 itkGetMacro(LayerId,int);
00111
00112 virtual void SetBias(ValueType) = 0;
00113 virtual ValueType GetBias() = 0;
00114
00115
00116 protected:
00117
00118 LayerBase();
00119 ~LayerBase();
00120
00122 virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00123
00124 unsigned int m_LayerType;
00125 unsigned int m_LayerId;
00126 unsigned int m_NumberOfNodes;
00127
00128 typename WeightSetType::Pointer m_InputWeightSet;
00129 typename WeightSetType::Pointer m_OutputWeightSet;
00130
00131 TransferFunctionPointer m_ActivationFunction;
00132 InputFunctionPointer m_NodeInputFunction;
00133
00134
00135
00136 };
00137
00138 }
00139 }
00140
00141 #ifndef ITK_MANUAL_INSTANTIATION
00142 #include "itkLayerBase.txx"
00143 #endif
00144
00145 #endif
00146