#ifndef MLPACK_METHODS_RL_WORKER_ONE_STEP_Q_LEARNING_WORKER_HPP
#define MLPACK_METHODS_RL_WORKER_ONE_STEP_Q_LEARNING_WORKER_HPP

#include <ensmallen.hpp>
#include <mlpack/methods/reinforcement_learning/training_config.hpp>
/**
 * Forward declaration of OneStepQLearningWorker.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
class OneStepQLearningWorker
{
 public:
  using StateType = typename EnvironmentType::State;
  using ActionType = typename EnvironmentType::Action;
  using TransitionType = std::tuple<StateType, ActionType, double, StateType>;
  //! Construct a one-step Q-learning worker with the given parameters.
  OneStepQLearningWorker(
      const UpdaterType& updater,
      const EnvironmentType& environment,
      const TrainingConfig& config,
      bool deterministic) :
      updater(updater),
      #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
      #endif
      environment(environment),
      config(config),
      deterministic(deterministic),
      pending(config.UpdateInterval())
  { Reset(); }
  //! Copy another OneStepQLearningWorker.
  OneStepQLearningWorker(const OneStepQLearningWorker& other) :
      updater(other.updater),
      #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
      #endif
      environment(other.environment),
      config(other.config),
      deterministic(other.deterministic),
      steps(other.steps),
      episodeReturn(other.episodeReturn),
      pending(other.pending),
      pendingIndex(other.pendingIndex),
      network(other.network),
      state(other.state)
  {
    #if ENS_VERSION_MAJOR >= 2
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif
  }
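  // Note: `updatePolicy` is rebuilt rather than copied; the instantiated
  // policy typically holds per-parameter optimizer state sized to this
  // worker's local network, so sharing the raw pointer would be unsafe.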
  //! Take ownership of another OneStepQLearningWorker.
  OneStepQLearningWorker(OneStepQLearningWorker&& other) :
      updater(std::move(other.updater)),
      #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
      #endif
      environment(std::move(other.environment)),
      config(std::move(other.config)),
      deterministic(std::move(other.deterministic)),
      steps(std::move(other.steps)),
      episodeReturn(std::move(other.episodeReturn)),
      pending(std::move(other.pending)),
      pendingIndex(std::move(other.pendingIndex)),
      network(std::move(other.network)),
      state(std::move(other.state))
  {
    #if ENS_VERSION_MAJOR >= 2
    other.updatePolicy = NULL;
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif
  }
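  // After the move, the source's `updatePolicy` is cleared to NULL so that
  // its destructor becomes a no-op for the policy.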
  //! Copy another OneStepQLearningWorker.
  OneStepQLearningWorker& operator=(const OneStepQLearningWorker& other)
  {
    if (&other == this)
      return *this;

    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif

    updater = other.updater;
    environment = other.environment;
    config = other.config;
    deterministic = other.deterministic;
    steps = other.steps;
    episodeReturn = other.episodeReturn;
    pending = other.pending;
    pendingIndex = other.pendingIndex;
    network = other.network;
    state = other.state;

    #if ENS_VERSION_MAJOR >= 2
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif

    return *this;
  }
  //! Take ownership of another OneStepQLearningWorker.
  OneStepQLearningWorker& operator=(OneStepQLearningWorker&& other)
  {
    if (&other == this)
      return *this;

    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif

    updater = std::move(other.updater);
    environment = std::move(other.environment);
    config = std::move(other.config);
    deterministic = std::move(other.deterministic);
    steps = std::move(other.steps);
    episodeReturn = std::move(other.episodeReturn);
    pending = std::move(other.pending);
    pendingIndex = std::move(other.pendingIndex);
    network = std::move(other.network);
    state = std::move(other.state);

    #if ENS_VERSION_MAJOR >= 2
    other.updatePolicy = NULL;
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif

    return *this;
  }
  //! Clean memory.
  ~OneStepQLearningWorker()
  {
    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif
  }

  //! Initialize the worker.
  void Initialize(NetworkType& learningNetwork)
  {
    #if ENS_VERSION_MAJOR == 1
    updater.Initialize(learningNetwork.Parameters().n_rows,
                       learningNetwork.Parameters().n_cols);
    #else
    delete updatePolicy;
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     learningNetwork.Parameters().n_rows,
                                     learningNetwork.Parameters().n_cols);
    #endif
    // Build the local copy of the network.
    network = learningNetwork;
  }
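  // ensmallen 1 initializes optimizer state via updater.Initialize(); from
  // ensmallen 2 on, that state lives in an instantiated Policy object, which
  // is why the preprocessor branches above differ.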
  //! The agent executes one step; returns true when the episode finishes.
  bool Step(NetworkType& learningNetwork,
            NetworkType& targetNetwork,
            size_t& totalSteps,
            PolicyType& policy,
            double& totalReward)
  {
    // Interact with the environment.
    arma::colvec actionValue;
    network.Predict(state.Encode(), actionValue);
    ActionType action = policy.Sample(actionValue, deterministic);
    StateType nextState;
    double reward = environment.Sample(state, action, nextState);
    bool terminal = environment.IsTerminal(nextState);

    episodeReturn += reward;
    steps++;
    terminal = terminal || steps >= config.StepLimit();
    // Deterministic mode is evaluation only: no update is performed.
    if (deterministic)
    {
      if (terminal)
      {
        totalReward = episodeReturn;
        Reset();
        network = learningNetwork;
        return true;
      }
      state = nextState;
      return false;
    }
    #pragma omp atomic
    totalSteps++;
    pending[pendingIndex++] = std::make_tuple(state, action, reward, nextState);
    if (terminal || pendingIndex >= config.UpdateInterval())
    {
      arma::mat totalGradients(learningNetwork.Parameters().n_rows,
          learningNetwork.Parameters().n_cols, arma::fill::zeros);
      for (size_t i = 0; i < pending.size(); ++i)
      {
        TransitionType& transition = pending[i];
        arma::colvec actionValue;
        targetNetwork.Predict(
            std::get<3>(transition).Encode(), actionValue);
        double targetActionValue = actionValue.max();
        if (terminal && i == pending.size() - 1)
          targetActionValue = 0;
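        // One-step Q-learning backup: the target is
        //   y = reward + discount * max_a Q_target(nextState, a),
        // and just y = reward for the final transition of a finished episode.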
        targetActionValue = std::get<2>(transition) +
            config.Discount() * targetActionValue;

        // Set the target for the taken action only; the other outputs keep
        // the network's own predictions and so contribute no error.
        arma::mat input = std::get<0>(transition).Encode();
        network.Forward(input, actionValue);
        actionValue[std::get<1>(transition).action] = targetActionValue;

        // Compute and accumulate the gradient.
        arma::mat gradients;
        network.Backward(input, actionValue, gradients);
        totalGradients += gradients;
      }
      // Clamp the accumulated gradients to avoid divergence.
      totalGradients.transform(
          [&](double gradient)
          {
            return std::min(std::max(gradient, -config.GradientLimit()),
                config.GradientLimit());
          });
      // Apply the update to the shared learning network.
      #if ENS_VERSION_MAJOR == 1
      updater.Update(learningNetwork.Parameters(), config.StepSize(),
          totalGradients);
      #else
      updatePolicy->Update(learningNetwork.Parameters(), config.StepSize(),
          totalGradients);
      #endif
      // Sync the local network with the shared network.
      network = learningNetwork;
      pendingIndex = 0;
    }
    // Periodically sync the shared target network.
    if (totalSteps % config.TargetNetworkSyncInterval() == 0)
      { targetNetwork = learningNetwork; }

    if (terminal)
    {
      totalReward = episodeReturn;
      Reset();
      network = learningNetwork;
      return true;
    }
    state = nextState;
    return false;
  }

 private:
  //! Reset the worker for a new episode.
  void Reset()
  {
    steps = 0;
    episodeReturn = 0;
    pendingIndex = 0;
    state = environment.InitialSample();
  }
  //! Locally-stored optimizer.
  UpdaterType updater;
  #if ENS_VERSION_MAJOR >= 2
  typename UpdaterType::template Policy<arma::mat, arma::mat>* updatePolicy;
  #endif
  //! Locally-stored task.
  EnvironmentType environment;
  //! Locally-stored hyper-parameters.
  TrainingConfig config;
  //! Whether the worker runs in deterministic (evaluation) mode.
  bool deterministic;
  //! Step count and accumulated reward of the current episode.
  size_t steps;
  double episodeReturn;
  //! Buffered transitions and the current write position.
  std::vector<TransitionType> pending;
  size_t pendingIndex;
  //! Local copy of the network and the agent's current state.
  NetworkType network;
  StateType state;
};

#endif
Member summary:

OneStepQLearningWorker(const UpdaterType& updater, const EnvironmentType& environment, const TrainingConfig& config, bool deterministic)
Construct a one-step Q-learning worker with the given parameters and environment.
OneStepQLearningWorker(const OneStepQLearningWorker& other)
Copy another OneStepQLearningWorker.
OneStepQLearningWorker(OneStepQLearningWorker&& other)
Take ownership of another OneStepQLearningWorker.
OneStepQLearningWorker& operator=(const OneStepQLearningWorker& other)
Copy another OneStepQLearningWorker.
OneStepQLearningWorker& operator=(OneStepQLearningWorker&& other)
Take ownership of another OneStepQLearningWorker.
~OneStepQLearningWorker()
Clean memory.
void Initialize(NetworkType& learningNetwork)
Initialize the worker.
bool Step(NetworkType& learningNetwork, NetworkType& targetNetwork, size_t& totalSteps, PolicyType& policy, double& totalReward)
The agent executes one step; returns true when the episode has finished.
typename EnvironmentType::State StateType
typename EnvironmentType::Action ActionType
std::tuple<StateType, ActionType, double, StateType> TransitionType

Referenced TrainingConfig accessors:

double Discount() const
Get the discount rate for future rewards.
double GradientLimit() const
Get the limit on the update gradient.
size_t StepLimit() const
Get the maximum number of steps per episode.
double StepSize() const
Get the step size of the optimizer.
size_t TargetNetworkSyncInterval() const
Get the interval for syncing the target network.
size_t UpdateInterval() const
Get the update interval.
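Example: a minimal sketch of driving the worker by hand (in practice the
asynchronous learning framework owns this loop). The environment, network,
policy, and updater types below (CartPole, FFN, GreedyPolicy,
ens::VanillaUpdate), their include paths, and the layer signatures are
assumptions drawn from the wider mlpack API and may differ across versions;
only OneStepQLearningWorker, TrainingConfig, Initialize(), and Step() come
from this header.

#include <iostream>
// Assumed include paths; adjust to your mlpack version.
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/reinforcement_learning/environment/cart_pole.hpp>
#include <mlpack/methods/reinforcement_learning/policy/greedy_policy.hpp>

using Environment = mlpack::rl::CartPole;
using Network = mlpack::ann::FFN<mlpack::ann::MeanSquaredError<>,
                                 mlpack::ann::GaussianInitialization>;
using Policy = mlpack::rl::GreedyPolicy<Environment>;

int main()
{
  mlpack::rl::TrainingConfig config;
  config.UpdateInterval() = 5;              // Length of the pending buffer.
  config.TargetNetworkSyncInterval() = 100; // Target-network sync period.
  config.StepLimit() = 200;                 // Max steps per episode.
  config.StepSize() = 0.01;
  config.Discount() = 0.99;

  // Shared learning and target networks (assumed layer API).
  Network learning;
  learning.Add<mlpack::ann::Linear<>>(4, 32);
  learning.Add<mlpack::ann::ReLULayer<>>();
  learning.Add<mlpack::ann::Linear<>>(32, 2);
  learning.ResetParameters();
  Network target = learning;

  Policy policy(1.0, 1000, 0.1);            // Annealed epsilon-greedy.
  ens::VanillaUpdate updater;

  mlpack::rl::OneStepQLearningWorker<Environment, Network,
      ens::VanillaUpdate, Policy> worker(updater, Environment(), config,
      /* deterministic */ false);
  worker.Initialize(learning);

  size_t totalSteps = 0;
  double episodeReturn = 0.0;
  while (totalSteps < 10000)
  {
    // Step() returns true exactly when an episode has finished.
    if (worker.Step(learning, target, totalSteps, policy, episodeReturn))
      std::cout << "Episode return: " << episodeReturn << std::endl;
  }
}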