PyTorchでA3C
PyTorchについて
Torchをbackendに持つPyTorchというライブラリがついこの間公開されました.
PyTorchはニューラルネットワークライブラリの中でも動的にネットワークを生成するタイプのライブラリになっていて, 計算が呼ばれる度に計算グラフを保存しておきその情報をもとに誤差逆伝搬します.
動的ネットワークライブラリに対して静的ネットワークライブラリ(tensorflowなど)があります.
静的ネットワークライブラリでは計算グラフの形をまず構築し, そのネットワークに対してデータフィードを行うことで順伝搬, 逆伝搬計算を行います.
動的ネットワークライブラリは入力データに対してネットワーク構造を決めることができるのでRNNや再帰構造をもつネットワークなどが書きやすいという利点があります.
動的ネットワークといえばchainerが有名で, PFNが開発をしています.
今回使用するPyTorchのコードを見てみるとchainerのコードと非常に似通っています.
それもそのはずで
@brandondamos Yep, the PyTorch autograd codebase started with a fork from Chainer -- but then rewrote it in highly optimized pure C
— James Bradbury (@jekbradbury) 2017年1月18日
とあるように自動微分の部分はchainerが元になっているようです.
@chrisemoody @brandondamos @Smerity @TsendeeMTS @ChainerOfficial I’d love to do it eventually (not easy b/c CuPy and Torch are incompatible)
— James Bradbury (@jekbradbury) 2017年1月18日
とありますが, tweetにある通りtorchとcupyの抽象化は大変そうです.
コードを見てみた感じ, forwardとbackwardを定義するのは同じですがlossに対するbackwardの呼び出しがCで書かれていてpythonでbackwardを遡るchainerより速いのかなぁと思ってますが, 深層学習のボトルネックは行列演算とメモリコピーだと思うのでchainerが遡る際に無駄なメモリコピーをしてない限りそこまで速くならない気がしてます(実際そこがchainerの負荷になってるかはわからないので誰か教えてください).
今回PyTrochに注目したのはmultiprocessingがライブラリ内でサポートしてあるので非同期, 分散計算が楽に書けるかも?と思ったからです.
Pythonのmultiprocessing
CPythonではGIL(Global Interpreter Lock)というものを採用していて, 1つのプロセス内で実行できるのは1つのthreadのみとなっています.
そのため複数のコアを使い倒すプログラムを書く時はmultiprocessingモジュールを使用して複数のプロセスを立ち上げる必要があります.
その際複数プロセス間で何かしらの値を共有する必要があります.
multiprocessingモジュールではmultiprocessing.RawArrayを用いると与えた配列を共有メモリ空間上に配置してくれます.
さらにこの配列をnumpyのarrayとして利用したい場合はnumpy.frombufferに確保したRawArrayを渡して共有メモリ空間上に配置されたnumpyのndarrayを得ることができます.
この方法を使ってmultiprocessingなA3C(Asynchronous Advantage Actor-Critic)を実装してある例にmuupanさんのhttps://github.com/muupan/async-rlがあります.
このようにしてmultiprocessingでnumpyを共有することができるんですが, PyTorchではフレームワークで共有をサポートしていて, modelに対してshare_memoryを呼ぶだけでそのmodel内のparameterを共有メモリ上に配置してくれます.
要はRawArray〜, frombuffer〜ってやるのは面倒だし現代的じゃないと思ったのでフレームワークでカバーしてくれるのはいいなと思って試してみました.
A3C
上で紹介したレポジトリにあるA3Cですが, これを再現実装します.https://arxiv.org/abs/1602.01783
強化学習については今回あまり詳しく説明しないのですが, A3Cで重要な部分だけかいつまんで説明します.
A3Cでは各エージェントが各プロセスで行動を行い(報酬, 観測, 行動)の組を集め, その情報を元にローカルのパラメーターに対する勾配を計算し, それを元にグローバルのパラメーターを更新します.
更新後グローバルなパラメーターとローカルなパラメーターは同期され再度rolloutを行います.
共有されるネットワークと各プロセスに存在するローカルなネットワークで通信が行われます.
実装
まずニューラルネットワークを実装します.
A3Cでは方策関数と状態価値観数をニューラルネットワークで表現します.
PyTorchではnn.Moduleを継承したクラスのforwardを定義してあげるとcallableなLayerを作ることができます.
class Policy(nn.Module): def __init__(self, num_actions, dim_obs, frame_num=4): super(Policy, self).__init__() self.num_actions = num_actions self.dim_obs = dim_obs self.frame_num = frame_num self.fc1 = nn.Linear(dim_obs*frame_num, 128) self.fc2 = nn.Linear(128, 128) self.p = nn.Linear(128, num_actions) self.v = nn.Linear(128, 1) def forward(self, x): x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) policy = self.p(x) value = self.v(x) return F.softmax(policy), value def sync(self, global_module): for p, gp in zip(self.parameters(), global_module.parameters()): p.data = gp.data.clone()
このpolicyをglobalとlocalで2つ宣言してあげます.
global_policyは共有メモリ上に配置します.
global_policy = Policy(env.action_space.n, env.observation_space.shape[0], args.frame_num) global_policy.share_memory() local_policy = Policy(env.action_space.n, env.observation_space.shape[0], args.frame_num)
このあと各プロセスを起動して非同期更新を行います.
勾配計算はtrain関数内で行われます.
def train(rank, global_policy, local_policy, optimizer, env, global_t, args): o = env.reset() step = 0 sum_rewards = 0 max_sum_rewards = 0 while global_t[0] < args.epoch: local_policy.sync(global_policy) observations = [] actions = [] values = [] rewards = [] probs = [] R = 0 for i in range(args.local_t_max): global_t += 1 step += 1 p, v = local_policy(o) a = p.multinomial() o, r, done, _ = env.step(a.data.squeeze()[0]) if rank == 0: sum_rewards += r if args.render: env.render() observations.append(o) actions.append(a) values.append(v) rewards.append(r) probs.append(p) if done: o = env.reset() if rank == 0: print('----------------------------------') print('total reward of the episode:', sum_rewards) print('----------------------------------') if args.save_mode == 'all': torch.save(local_policy, os.path.join(args.log_dir, args.save_name+"_{}.pkl".format(global_t[0]))) elif args.save_mode == 'last': torch.save(local_policy, os.path.join(args.log_dir, args.save_name+'.pkl')) elif args.save_mode == 'max': if max_sum_rewards < sum_rewards: torch.save(local_policy, os.path.join(args.log_dir, args.save_name+'.pkl')) max_sum_rewards = sum_rewards sum_rewards = 0 step = 0 break else: _, v = local_policy(o) R += v.data.squeeze()[0] returns = [] for r in rewards[::-1]: R = r + 0.99 * R returns.insert(0, R) returns = torch.Tensor(returns) if len(returns) > 1: returns = (returns-returns.mean()) / (returns.std()+args.eps) v_loss = 0 entropy = 0 for a, v, p, r in zip(actions, values, probs, returns): a.reinforce(r - v.data.squeeze()) _v_loss = nn.MSELoss()(v, Variable(torch.Tensor([r]))) v_loss += _v_loss entropy += (p * (p + args.eps).log()).sum() v_loss = v_loss * 0.5 * args.v_loss_coeff entropy = entropy * args.entropy_beta optimizer.zero_grad() final_node = [v_loss, entropy] + actions gradients = [torch.ones(1), torch.ones(1)] + [None] * len(actions) autograd.backward(final_node, gradients) new_lr = (args.epoch - global_t[0]) / args.epoch * args.lr optimizer.step(new_lr)
余談なんですが, このコードやPyTorchのexampleshttps://github.com/pytorch/examples/でのreinforce.pyやactor-critic.pyで使われているmethodにreinforceがあります.
reinforceはREINFORCEができます.
REINFORCEは確率分布に対して誤差逆伝搬法を使いたい時に使う手法で, サンプリングによって微分ができなくなる際に勾配を推定する手法です.
PyTorchではStochasticFunctionというクラスがあり, そこからサンプリングされたものに対してはREINFORCEを使うことができます.
これはlog(p)をlossにすればいいだけの話なんですが, こっちのほうがわかりやすくていいと思います.
最後はこのtrainをmultiprocessingで並列実行すると非同期に学習してくれます.
processes = [] for rank in range(args.num_process): p = mp.Process(target=train, args=(rank, global_policy, local_policy, optimizer, env, global_t, args)) p.start() processes.append(p) for p in processes: p.join()
結果
とりあえずCartPoleだけ試してみました.
まぁまぁ学習してそうでしょうか?
今回のコードとか
コードは全部https://github.com/rarilurelo/pytorch_a3cにあります.
PyTorchのインストールは本家ページhttp://pytorch.org/の通りにやるとすぐできます.
あとがき
今日ちょうど秋葉さんがchainerを分散並列で高速化した話https://www.youtube.com/watch?v=wPr-yuJjvFQを見てました.
その中でコミュニティの認識としては非同期更新より同期更新がやはりいいと言うことが述べられていました.
実際A3Cをfxデータに適用したことがあるんですが, 途中まではうまい感じで学習しても方策が壊れてしまうことが多々ありハイパラ調整は難しいです.
A3CがOn-Policyの学習アルゴリズムであることも関係していると思われるのですが, 一回変な方策ができるとそこから回復することはありません.
TRPOhttps://arxiv.org/abs/1502.05477では方策を崩さずに単調な増加を目指すために色々頑張っているよう思うのですがTRPOの強さを考えると強化学習の難しい課題なのではないかと思います.