mlpack  3.4.2
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
one_step_q_learning_worker.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RL_WORKER_ONE_STEP_Q_LEARNING_WORKER_HPP
14 #define MLPACK_METHODS_RL_WORKER_ONE_STEP_Q_LEARNING_WORKER_HPP
15 
17 
18 namespace mlpack {
19 namespace rl {
20 
29 template <
30  typename EnvironmentType,
31  typename NetworkType,
32  typename UpdaterType,
33  typename PolicyType
34 >
35 class OneStepQLearningWorker
36 {
37  public:
38  using StateType = typename EnvironmentType::State;
39  using ActionType = typename EnvironmentType::Action;
40  using TransitionType = std::tuple<StateType, ActionType, double, StateType>;
41 
52  const UpdaterType& updater,
53  const EnvironmentType& environment,
54  const TrainingConfig& config,
55  bool deterministic):
56  updater(updater),
57  #if ENS_VERSION_MAJOR >= 2
58  updatePolicy(NULL),
59  #endif
60  environment(environment),
61  config(config),
62  deterministic(deterministic),
63  pending(config.UpdateInterval())
64  { Reset(); }
65 
72  updater(other.updater),
73  #if ENS_VERSION_MAJOR >= 2
74  updatePolicy(NULL),
75  #endif
76  environment(other.environment),
77  config(other.config),
78  deterministic(other.deterministic),
79  steps(other.steps),
80  episodeReturn(other.episodeReturn),
81  pending(other.pending),
82  pendingIndex(other.pendingIndex),
83  network(other.network),
84  state(other.state)
85  {
86  #if ENS_VERSION_MAJOR >= 2
87  updatePolicy = new typename UpdaterType::template
88  Policy<arma::mat, arma::mat>(updater,
89  network.Parameters().n_rows,
90  network.Parameters().n_cols);
91  #endif
92 
93  Reset();
94  }
95 
102  updater(std::move(other.updater)),
103  #if ENS_VERSION_MAJOR >= 2
104  updatePolicy(NULL),
105  #endif
106  environment(std::move(other.environment)),
107  config(std::move(other.config)),
108  deterministic(std::move(other.deterministic)),
109  steps(std::move(other.steps)),
110  episodeReturn(std::move(other.episodeReturn)),
111  pending(std::move(other.pending)),
112  pendingIndex(std::move(other.pendingIndex)),
113  network(std::move(other.network)),
114  state(std::move(other.state))
115  {
116  #if ENS_VERSION_MAJOR >= 2
117  other.updatePolicy = NULL;
118 
119  updatePolicy = new typename UpdaterType::template
120  Policy<arma::mat, arma::mat>(updater,
121  network.Parameters().n_rows,
122  network.Parameters().n_cols);
123  #endif
124  }
125 
132  {
133  if (&other == this)
134  return *this;
135 
136  #if ENS_VERSION_MAJOR >= 2
137  delete updatePolicy;
138  #endif
139 
140  updater = other.updater;
141  environment = other.environment;
142  config = other.config;
143  deterministic = other.deterministic;
144  steps = other.steps;
145  episodeReturn = other.episodeReturn;
146  pending = other.pending;
147  pendingIndex = other.pendingIndex;
148  network = other.network;
149  state = other.state;
150 
151  #if ENS_VERSION_MAJOR >= 2
152  updatePolicy = new typename UpdaterType::template
153  Policy<arma::mat, arma::mat>(updater,
154  network.Parameters().n_rows,
155  network.Parameters().n_cols);
156  #endif
157 
158  Reset();
159 
160  return *this;
161  }
162 
169  {
170  if (&other == this)
171  return *this;
172 
173  #if ENS_VERSION_MAJOR >= 2
174  delete updatePolicy;
175  #endif
176 
177  updater = std::move(other.updater);
178  environment = std::move(other.environment);
179  config = std::move(other.config);
180  deterministic = std::move(other.deterministic);
181  steps = std::move(other.steps);
182  episodeReturn = std::move(other.episodeReturn);
183  pending = std::move(other.pending);
184  pendingIndex = std::move(other.pendingIndex);
185  network = std::move(other.network);
186  state = std::move(other.state);
187 
188  #if ENS_VERSION_MAJOR >= 2
189  other.updatePolicy = NULL;
190 
191  updatePolicy = new typename UpdaterType::template
192  Policy<arma::mat, arma::mat>(updater,
193  network.Parameters().n_rows,
194  network.Parameters().n_cols);
195  #endif
196 
197  return *this;
198  }
199 
204  {
205  #if ENS_VERSION_MAJOR >= 2
206  delete updatePolicy;
207  #endif
208  }
209 
214  void Initialize(NetworkType& learningNetwork)
215  {
216  #if ENS_VERSION_MAJOR == 1
217  updater.Initialize(learningNetwork.Parameters().n_rows,
218  learningNetwork.Parameters().n_cols);
219  #else
220  delete updatePolicy;
221 
222  updatePolicy = new typename UpdaterType::template
223  Policy<arma::mat, arma::mat>(updater,
224  learningNetwork.Parameters().n_rows,
225  learningNetwork.Parameters().n_cols);
226  #endif
227 
228  // Build local network.
229  network = learningNetwork;
230  }
231 
243  bool Step(NetworkType& learningNetwork,
244  NetworkType& targetNetwork,
245  size_t& totalSteps,
246  PolicyType& policy,
247  double& totalReward)
248  {
249  // Interact with the environment.
250  arma::colvec actionValue;
251  network.Predict(state.Encode(), actionValue);
252  ActionType action = policy.Sample(actionValue, deterministic);
253  StateType nextState;
254  double reward = environment.Sample(state, action, nextState);
255  bool terminal = environment.IsTerminal(nextState);
256 
257  episodeReturn += reward;
258  steps++;
259 
260  terminal = terminal || steps >= config.StepLimit();
261  if (deterministic)
262  {
263  if (terminal)
264  {
265  totalReward = episodeReturn;
266  Reset();
267  // Sync with latest learning network.
268  network = learningNetwork;
269  return true;
270  }
271  state = nextState;
272  return false;
273  }
274 
275  #pragma omp atomic
276  totalSteps++;
277 
278  pending[pendingIndex] = std::make_tuple(state, action, reward, nextState);
279  pendingIndex++;
280 
281  if (terminal || pendingIndex >= config.UpdateInterval())
282  {
283  // Initialize the gradient storage.
284  arma::mat totalGradients(learningNetwork.Parameters().n_rows,
285  learningNetwork.Parameters().n_cols, arma::fill::zeros);
286  for (size_t i = 0; i < pending.size(); ++i)
287  {
288  TransitionType &transition = pending[i];
289 
290  // Compute the target state-action value.
291  arma::colvec actionValue;
292  #pragma omp critical
293  {
294  targetNetwork.Predict(
295  std::get<3>(transition).Encode(), actionValue);
296  };
297  double targetActionValue = actionValue.max();
298  if (terminal && i == pending.size() - 1)
299  targetActionValue = 0;
300  targetActionValue = std::get<2>(transition) +
301  config.Discount() * targetActionValue;
302 
303  // Compute the training target for current state.
304  arma::mat input = std::get<0>(transition).Encode();
305  network.Forward(input, actionValue);
306  actionValue[std::get<1>(transition).action] = targetActionValue;
307 
308  // Compute gradient.
309  arma::mat gradients;
310  network.Backward(input, actionValue, gradients);
311 
312  // Accumulate gradients.
313  totalGradients += gradients;
314  }
315 
316  // Clamp the accumulated gradients.
317  totalGradients.transform(
318  [&](double gradient)
319  { return std::min(std::max(gradient, -config.GradientLimit()),
320  config.GradientLimit()); });
321 
322  // Perform async update of the global network.
323  #if ENS_VERSION_MAJOR == 1
324  updater.Update(learningNetwork.Parameters(), config.StepSize(),
325  totalGradients);
326  #else
327  updatePolicy->Update(learningNetwork.Parameters(),
328  config.StepSize(), totalGradients);
329  #endif
330 
331  // Sync the local network with the global network.
332  network = learningNetwork;
333 
334  pendingIndex = 0;
335  }
336 
337  // Update global target network.
338  if (totalSteps % config.TargetNetworkSyncInterval() == 0)
339  {
340  #pragma omp critical
341  { targetNetwork = learningNetwork; }
342  }
343 
344  policy.Anneal();
345 
346  if (terminal)
347  {
348  totalReward = episodeReturn;
349  Reset();
350  return true;
351  }
352  state = nextState;
353  return false;
354  }
355 
356  private:
360  void Reset()
361  {
362  steps = 0;
363  episodeReturn = 0;
364  pendingIndex = 0;
365  state = environment.InitialSample();
366  }
367 
369  UpdaterType updater;
370  #if ENS_VERSION_MAJOR >= 2
371  typename UpdaterType::template Policy<arma::mat, arma::mat>* updatePolicy;
372  #endif
373 
375  EnvironmentType environment;
376 
378  TrainingConfig config;
379 
381  bool deterministic;
382 
384  size_t steps;
385 
387  double episodeReturn;
388 
390  std::vector<TransitionType> pending;
391 
393  size_t pendingIndex;
394 
396  NetworkType network;
397 
399  StateType state;
400 };
401 
402 } // namespace rl
403 } // namespace mlpack
404 
405 #endif
void Initialize(NetworkType &learningNetwork)
Initialize the worker.
std::tuple< StateType, ActionType, double, StateType > TransitionType
OneStepQLearningWorker & operator=(OneStepQLearningWorker &&other)
Take ownership of another OneStepQLearningWorker.
OneStepQLearningWorker(const OneStepQLearningWorker &other)
Copy another OneStepQLearningWorker.
OneStepQLearningWorker(const UpdaterType &updater, const EnvironmentType &environment, const TrainingConfig &config, bool deterministic)
Construct one step Q-Learning worker with the given parameters and environment.
double Discount() const
Get the discount rate for future reward.
Forward declaration of OneStepQLearningWorker.
double GradientLimit() const
Get the limit of update gradient.
if(NOT R_FOUND OR NOT R_RCPP OR NOT R_RCPPARMADILLO OR NOT R_RCPPENSMALLEN OR NOT R_BH OR NOT R_ROXYGEN2 OR NOT R_TESTTHAT) if(NOT R_FOUND OR NOT R_RCPP OR NOT R_RCPPARMADILLO OR NOT R_RCPPENSMALLEN OR NOT R_BH OR NOT R_ROXYGEN2 OR NOT R_TESTTHAT) file(READ"$
Definition: CMakeLists.txt:61
size_t StepLimit() const
Get the maximum steps of each episode.
typename EnvironmentType::Action ActionType
OneStepQLearningWorker(OneStepQLearningWorker &&other)
Take ownership of another OneStepQLearningWorker.
bool Step(NetworkType &learningNetwork, NetworkType &targetNetwork, size_t &totalSteps, PolicyType &policy, double &totalReward)
The agent will execute one step.
OneStepQLearningWorker & operator=(const OneStepQLearningWorker &other)
Copy another OneStepQLearningWorker.
size_t UpdateInterval() const
Get the update interval.
double StepSize() const
Get the step size of the optimizer.
typename EnvironmentType::State StateType
size_t TargetNetworkSyncInterval() const
Get the interval for syncing target network.