GANを使っていくつかのデータセットの学習と画像生成
はじめに
深層学習に関する知識が一昔前で止まってしまっているので、少しずつ追っていこうと思います。 今回はGANに関する内容です。
GAN(Generative Adversarial Network)を使用したモデルに対する敵対的なデータの生成は 従来の機械学習のようなデータの識別を行うのではなく、データの生成を行うという点で以前から興味があったので試してみようと思います。
GANは識別器と生成器の間でそれぞれ評価関数を設定し、双方がその評価関数の結果を良くするように互いに学習を行うような形式で行われます。 生成器は識別器が誤判別しそうなデータを生成しようとする(相手の評価関数の結果を出来るだけ悪くする)ので、敵対的な生成ネットワークと呼ばれます。 この場合判別というのは、学習対象の本物データセットか生成器によって生成されたデータであるかどうかということになります。
今回使用するデータセットは以下になります。 データセットはKaggleのDatasetsから取得したものになります。
- 絵画: https://www.kaggle.com/datasets/ikarus777/best-artworks-of-all-time
- ピクトグラム: https://www.kaggle.com/datasets/olgabelitskaya/art-pictogram
- 抽象画: https://www.kaggle.com/datasets/flash10042/abstract-paintings-dataset
- 風景画: https://www.kaggle.com/datasets/arnaud58/landscape-pictures
抽象画のデータは抽象なだけあって特徴がほとんどないようなものなので、 生成する画像の質を判断できるのかどうかという懸念がありますが試してみます。 ピクトグラムは64x64サイズで画像サイズが小さく、単純な描画であることが多いのでうまく生成してくれることを期待します。
また既存の事前学習モデルを使用して、画像を生成することも最後に行います。 事前学習済みモデルは独自の結果を生成することはできませんが、 訓練サンプルを収集して選別する作業と長時間の学習が省けるので手軽に比較的質の高い画像を生成することができます (解像度が高くかつノイズが少なく学習データに近い物を生成できるという意味で)。
DCGANでデータセットを使用した学習・画像生成
DCGANは敵対的生成ネットワークの一つですが、識別器に畳み込み層、生成器に逆(transpose)畳み込み層を使用します。
学習の流れとしてはまず識別器の学習ステップから行います。識別器に本物のデータを渡して伝播させ損失関数の計算を行います。 その後生成器側で生成した偽データも識別器に入力して損失の計算を行い、勾配を計算し、本物と偽物併せて勾配を計算してパラメータの更新を行います。 次いで生成器の学習ですが、先ほど作成した偽データを更新済みの識別器で判別しその結果の損失から勾配を計算して、パラメータの更新を行います。 この時、識別器側での偽データの判定を本物データのラベルと解釈して損失計算を行います。 以上を1ステップとして繰り返し学習を行います。
コードについてはほぼpytorch exampleのものを使用するので省略します。
絵画
https://www.kaggle.com/datasets/ikarus777/best-artworks-of-all-time のデータセットを使用します。
このデータセットはどの画家が描いた絵画かを識別するために作成されたセットで風景画や人物画、抽象画など多様な種類の絵画を有するデータセットになります。 合計8355枚になります。 200epochsの学習を行ったモデルで生成した画像例が以下になります
肖像画や風景画らしきものが生成されているのがわかります。
ピクトグラム
https://www.kaggle.com/datasets/olgabelitskaya/art-pictogram のデータセットを使用します。 ピクトグラムはその画像を見て、相手に情報を伝える役目を担っています。ですので画像は意図が相手に伝わりやすいように簡潔で分かりやすいものになっている傾向があります。 上記のデータには花や鳥、飛行機や車、船など16種の画像が計3545枚あります。
学習は400epochs行いました。 学習後に生成した画像を示します。
何かの絵っぽいように見えますが、何の絵かの判別は難しいです。 いろいろな種類が混ざったような絵にも見えるので、何か特定の絵を生成したいのならもう少しジャンルを絞るべきだったかもしれません。
画像を蝶の画像に絞って学習してみました。 学習に使用したデータが少ないので多様性に欠けますが、蝶の画像であることがわかるようになりました。
抽象画
https://www.kaggle.com/datasets/flash10042/abstract-paintings-dataset
これはWikiArt.orgから取得された抽象画14359枚のデータセットになります。 単純な幾何学的模様から、構図を意識したような複雑なものまであります。 200epochsの学習を行ったモデルで生成した画像例が以下になります
芸術の素養がないので抽象画としての質は判断はできませんが、素人目にはそれなりに抽象画にはなっているようです。
風景画
https://www.kaggle.com/datasets/arnaud58/landscape-pictures
これは山や海や砂漠などの風景の画像4319枚のデータセットになります。 200epochsの学習を行ったモデルで生成した画像例が以下になります。
風景画に近いものが生成されましたが、細部の表現は難しいようです。
事前学習モデルの使用(StyleGAN)
StyleGAN
StyleGANについて、元論文https://arxiv.org/abs/1812.04948 の緒言部分を読みましたが、 StyleGANは従来のGANに対して生成器のプロセスの制御をある程度可能にすることで高解像度と質の向上を可能にしたモデルのようです。
少しだけ詳細を述べると、従来のGANの生成器の入力で使用していた潜在変数zの代わりにMIP layerから生成したWに置き換え、 生成器ネットワークの各layerスケールにてWをy(styleと名付けられている)に変換して適応インスタンス正規化処理(AdaIN)を通じて、 画像のスタイルをダイレクトに制御することで、スケールに応じて画像の細部までの生成の制御を行うことを可能にしているようです。 最後に各layerにガウシアンノイズを加えて、確率的な生成手段を得ています。
おそらく潜在変数Wを直接いじることで画像生成の制御を行うことが出来るので、ここで希望の画像にするための微調整を行うのでしょう。
google colaboratoryを介して生成を行いましたが、StyleGAN https://github.com/NVlabs/stylegan.git をcloneして使用する場合、 tensorflow==1.15.2とtensorflow-gpu==1.15.2をインストールする必要があります。
!pip install tensorflow==1.15.2
!pip install tensorflow-gpu==1.15.2
DCGANで学習したものよりも解像度が高く細部にも対応しており質が高いようです。 ただStyleGANのサンプルの生成画像も見ましたが、ここまで鮮明な画像になるのも不思議に思えます。
まとめ
多様なデータセットを用いて学習を行いましたが、GANがどういったものかを少しずつつかめたように思います。 損失関数の結果から全体的に生成器よりも識別器の方が勝ってしまっている傾向があります。 これは生成器の学習は識別器を介した学習であり直観的には当然のように感じます。 ネットワークの規模の大きさから、精度を上げるためには生成したい画像の方向性を最初に決めてしまって学習するデータセットのジャンルを絞る必要があるように思います。 今回はGANを用いてお試しで学習を行いましたが、今後改善をするならば目的に合わせてまずデータセットの選別をよく考えることから始めるべきかと思います。 データセットに関してはGANに限った話でもなく機械学習一般的に言えることでもありますが。 また生成した画像の傾向としてはなんとなく似ている画像が多く、細部の表現が難しいようです。 そういう意味では比較的細かい部分が重要になる具体的な情報を示す画像よりは風景画や抽象画の方が向いているように感じます。 こういった部分はモデル設計から考慮する必要があるのかもしれません。
特に種々の不都合がなければ、事前学習モデルを使用するのがはやいですね。