mlpack 3.4.2
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
base_layer.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_BASE_LAYER_HPP
14 #define MLPACK_METHODS_ANN_LAYER_BASE_LAYER_HPP
15 
16 #include <mlpack/prereqs.hpp>
30 
31 namespace mlpack {
32 namespace ann {
33 
60 template <
61  class ActivationFunction = LogisticFunction,
62  typename InputDataType = arma::mat,
63  typename OutputDataType = arma::mat
64 >
65 class BaseLayer
66 {
67  public:
72  {
73  // Nothing to do here.
74  }
75 
83  template<typename InputType, typename OutputType>
84  void Forward(const InputType& input, OutputType& output)
85  {
86  ActivationFunction::Fn(input, output);
87  }
88 
98  template<typename eT>
99  void Backward(const arma::Mat<eT>& input,
100  const arma::Mat<eT>& gy,
101  arma::Mat<eT>& g)
102  {
103  arma::Mat<eT> derivative;
104  ActivationFunction::Deriv(input, derivative);
105  g = gy % derivative;
106  }
107 
109  OutputDataType const& OutputParameter() const { return outputParameter; }
111  OutputDataType& OutputParameter() { return outputParameter; }
112 
114  OutputDataType const& Delta() const { return delta; }
116  OutputDataType& Delta() { return delta; }
117 
121  template<typename Archive>
122  void serialize(Archive& /* ar */, const unsigned int /* version */)
123  {
124  /* Nothing to do here */
125  }
126 
127  private:
129  OutputDataType delta;
130 
132  OutputDataType outputParameter;
133 }; // class BaseLayer
134 
135 // Convenience typedefs.
136 
140 template <
141  class ActivationFunction = LogisticFunction,
142  typename InputDataType = arma::mat,
143  typename OutputDataType = arma::mat
144 >
145 using SigmoidLayer = BaseLayer<
146  ActivationFunction, InputDataType, OutputDataType>;
147 
151 template <
152  class ActivationFunction = IdentityFunction,
153  typename InputDataType = arma::mat,
154  typename OutputDataType = arma::mat
155 >
156 using IdentityLayer = BaseLayer<
157  ActivationFunction, InputDataType, OutputDataType>;
158 
162 template <
163  class ActivationFunction = RectifierFunction,
164  typename InputDataType = arma::mat,
165  typename OutputDataType = arma::mat
166 >
167 using ReLULayer = BaseLayer<
168  ActivationFunction, InputDataType, OutputDataType>;
169 
173 template <
174  class ActivationFunction = TanhFunction,
175  typename InputDataType = arma::mat,
176  typename OutputDataType = arma::mat
177 >
178 using TanHLayer = BaseLayer<
179  ActivationFunction, InputDataType, OutputDataType>;
180 
184 template <
185  class ActivationFunction = SoftplusFunction,
186  typename InputDataType = arma::mat,
187  typename OutputDataType = arma::mat
188 >
189 using SoftPlusLayer = BaseLayer<
190  ActivationFunction, InputDataType, OutputDataType>;
191 
195 template <
196  class ActivationFunction = HardSigmoidFunction,
197  typename InputDataType = arma::mat,
198  typename OutputDataType = arma::mat
199 >
201  ActivationFunction, InputDataType, OutputDataType>;
202 
206 template <
207  class ActivationFunction = SwishFunction,
208  typename InputDataType = arma::mat,
209  typename OutputDataType = arma::mat
210 >
212  ActivationFunction, InputDataType, OutputDataType>;
213 
217 template <
218  class ActivationFunction = MishFunction,
219  typename InputDataType = arma::mat,
220  typename OutputDataType = arma::mat
221 >
223  ActivationFunction, InputDataType, OutputDataType>;
224 
228 template <
229  class ActivationFunction = LiSHTFunction,
230  typename InputDataType = arma::mat,
231  typename OutputDataType = arma::mat
232 >
234  ActivationFunction, InputDataType, OutputDataType>;
235 
239 template <
240  class ActivationFunction = GELUFunction,
241  typename InputDataType = arma::mat,
242  typename OutputDataType = arma::mat
243 >
245  ActivationFunction, InputDataType, OutputDataType>;
246 
250 template <
251  class ActivationFunction = ElliotFunction,
252  typename InputDataType = arma::mat,
253  typename OutputDataType = arma::mat
254 >
256  ActivationFunction, InputDataType, OutputDataType>;
257 
261 template <
262  class ActivationFunction = ElishFunction,
263  typename InputDataType = arma::mat,
264  typename OutputDataType = arma::mat
265 >
267  ActivationFunction, InputDataType, OutputDataType>;
268 
272 template <
273  class ActivationFunction = GaussianFunction,
274  typename InputDataType = arma::mat,
275  typename OutputDataType = arma::mat
276 >
278  ActivationFunction, InputDataType, OutputDataType>;
279 
280 } // namespace ann
281 } // namespace mlpack
282 
283 #endif
The identity function, defined by $f(x) = x$.
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function $f(x)$ by propagating the activity forward through $f$.
Definition: base_layer.hpp:84
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: base_layer.hpp:111
BaseLayer()
Create the BaseLayer object.
Definition: base_layer.hpp:71
The LiSHT function, defined by $f(x) = x \tanh(x)$.
OutputDataType & Delta()
Modify the delta.
Definition: base_layer.hpp:116
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function $f(x)$ by propagating $x$ backwards through $f$.
Definition: base_layer.hpp:99
void serialize(Archive &, const unsigned int)
Serialize the layer.
Definition: base_layer.hpp:122
The tanh function, defined by $f(x) = \tanh(x) = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$.
The core includes that mlpack expects; standard C++ includes and Armadillo.
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: base_layer.hpp:109
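The ELiSH function, defined by $f(x) = \begin{cases} \frac{x}{1 + e^{-x}} & x \ge 0 \\ \frac{e^{x} - 1}{1 + e^{-x}} & x < 0 \end{cases}$.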
Implementation of the base layer.
Definition: base_layer.hpp:65
BaseLayer< ActivationFunction, InputDataType, OutputDataType > SigmoidLayer
Standard Sigmoid-Layer using the logistic activation function.
Definition: base_layer.hpp:146
The Mish function, defined by $f(x) = x \tanh\left(\ln(1 + e^{x})\right)$.
The Gaussian function, defined by $f(x) = e^{-x^{2}}$.
The Elliot function, defined by $f(x) = \frac{x}{1 + |x|}$.
The swish function, defined by $f(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$.
The softplus function, defined by $f(x) = \ln(1 + e^{x})$.
OutputDataType const & Delta() const
Get the delta.
Definition: base_layer.hpp:114
The hard sigmoid function, defined by $f(x) = \min\left(1, \max(0, 0.2x + 0.5)\right)$.
The GELU function, defined by $f(x) = 0.5\,x\left(1 + \tanh\left(\sqrt{2/\pi}\,(x + 0.044715\,x^{3})\right)\right)$.
The rectifier function, defined by $f(x) = \max(0, x)$.