PyLearnMLCR 03: Neural Networks

This chapter covers everyone's favorite: neural networks. Unlike SVMs and decision trees, a neural network has an enormous number of parameters to optimize, so finding the exact optimum by exhaustive search or by solving some closed-form equation is essentially impossible. Instead, the standard approach is (steepest) gradient descent, or one of its variants, which uses derivatives to search for reasonably good parameters. When parameters are optimized with derivatives, the final model often differs greatly from run to run, depending on the initial values and the optimization method used. For these reasons, the resulting model tends to be far less stable than an SVM or a decision tree. On the other hand, optimization algorithms that converge to good parameters have improved considerably in recent years, so neural networks can now be used quite effectively.

Note that neural networks often fail to learn effectively when the training data is small. SVMs and similar methods cope with that situation comparatively well.

Classification

Generating dummy data

This is the same data used in the decision-tree chapter, so the explanation is omitted.

In [1]:
import numpy as np
np.random.seed(1) # fix the pseudo-random seed so the same numbers are drawn every run

# Class 1: high-performing group
mu = [7, 7]
sigma = [[0.5, 0.05], [0.05, 0.5]]
Dat01 = np.random.multivariate_normal(mu, sigma, 100)
Label01=[]
for i in range(0, len(Dat01)):
    Label01.append("Good")

# Class 2: middle group
mu = [5.5, 4]
sigma = [[0.8, -0.5], [-0.5, 0.8]]
Dat02 = np.random.multivariate_normal(mu, sigma, 100)
Label02=[]
for i in range(0, len(Dat02)):
    Label02.append("Middle")

# Class 3: low-performing group
mu = [2, 5]
sigma = [[0.3, 0.0], [0.0, 3]]
Dat03 = np.random.multivariate_normal(mu, sigma, 100)
Label03=[]
for i in range(0, len(Dat03)):
    Label03.append("Bad")

Checking the data with a scatter plot

Next, let's visualize the data with a scatter plot. This is also the same as in the decision-tree chapter, so the explanation is omitted.

In [3]:
import matplotlib.pyplot as plt

x = [Dat01[:,0], Dat02[:,0], Dat03[:,0]]
y = [Dat01[:,1], Dat02[:,1], Dat03[:,1]]

plt.figure(figsize=(5, 4)) # width and height of the figure

# Scatter plot for Good
plt.scatter(Dat01[:,0], Dat01[:,1], s=50, c='blue', marker='s', 
            alpha=0.8, linewidths=0.5, edgecolors='black')

# Scatter plot for Middle
plt.scatter(Dat02[:,0], Dat02[:,1], s=50, c='orange', marker='s', 
            alpha=0.8, linewidths=0.5, edgecolors='black')

# Scatter plot for Bad
plt.scatter(Dat03[:,0], Dat03[:,1], s=50, c='red', marker='s', 
            alpha=0.8, linewidths=0.5, edgecolors='black')

plt.title("Relationship: X1, X2 and Score")  # title
plt.xlabel("X1: Studying Time") # axis label
plt.ylabel("X2: Understanding") # axis label
plt.grid(True)      # grid lines (True: draw, False: do not draw)
plt.xlim(0, 10)  # min/max of the horizontal axis
plt.ylim(0, 10)  # min/max of the vertical axis

plt.legend(["Good", "Middle", "Bad"], loc="upper right") # legend
Out[3]:
<matplotlib.legend.Legend at 0x11145d860>

Creating the training dataset

This is also the same as in the decision-tree chapter. The first column of X holds studying time and the second holds understanding, and Y holds the Good, Middle, and Bad labels. With 100 people per class, the dataset contains the study habits and grades of 300 people in total.

In [4]:
X=np.concatenate([Dat01, Dat02, Dat03]) # features (the "questions")
Y=np.concatenate([Label01, Label02, Label03]) # ground-truth labels
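
As a quick sanity check (a small addition, not part of the original notebook), you can confirm the shape of the assembled dataset:

print(X.shape)  # (300, 2): 300 people, 2 features (studying time, understanding)
print(Y.shape)  # (300,): one label per person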

Training the neural network

Now let's look at how to build a neural network. There are many hyperparameters; the main ones are listed here. Only an outline is given, so if you want to call yourself an AI specialist, look them up and study them on your own.

  • hidden_layer_sizes (number of neurons in the hidden layers):

    • A neural network consists of an input layer, hidden layers, and an output layer. The input layer corresponds to the dimensionality of the input data, so it is determined automatically (here it is 2-dimensional). The output layer corresponds to the number of classes, so it is also determined automatically (here we solve a 3-way problem, Bad/Middle/Good, so it is 3-dimensional). The hidden layers must be chosen by the analyst. They are specified as follows.
    • hidden_layer_sizes = (10, ): a 3-layer NN whose hidden layer (the 2nd layer) has 10 neurons
    • hidden_layer_sizes = (10, 5, ): a 4-layer NN whose 2nd layer has 10 neurons and whose 3rd layer has 5 neurons
  • activation [default relu] (activation function):

    • Choose one of 'identity' (linear), 'logistic' (sigmoid), or 'relu' (ReLU). The linear option trains fast, but its decision boundaries in feature space can only be straight lines. In most cases the best choice is found by trial and error: build many models and try them all.
  • solver [default adam] (optimization method):

    • Choose one of lbfgs (a quasi-Newton method), sgd (stochastic gradient descent), or adam (Adaptive Moment Estimation). sgd or adam usually works well, but here too the usual approach is trial and error.
  • alpha [default 0.0001] (weight of the L2 regularization term):

    • The objective function used to train an NN is normally built around reducing the difference between the observed and the estimated values. In addition to this, the idea of squaring the NN's weight and bias parameters, summing them, and trying to reduce that sum as well is called L2 regularization. It keeps the weights and biases stable, which helps prevent the NN from producing wildly abnormal outputs. alpha controls how much importance is given to the L2 term. Since the NN's objective is the sum of the "observed-vs-estimated difference" and the "L2 term", emphasizing one necessarily de-emphasizes the other, so the balance matters. The best value depends on the problem; if you are unsure, using the default (i.e., not specifying it) is fine.
  • learning_rate_init [default 0.001] (learning rate):

    • During NN training, we take the function whose independent variables are the weight and bias parameters and whose value is the difference between the observed and estimated values, and compute its partial derivatives. The parameters are then updated by subtracting those derivatives from the weights and biases. If the raw derivative were subtracted as-is, the parameters would change too much, so the derivative is normally multiplied by a very small number to shrink each update. Because this number determines the size of each learning step, it is called the learning rate (see the sketch after this list). With a large learning rate, performance improves quickly per iteration, but the parameters may jump over good solutions; with a small learning rate, training takes much longer, but good solutions are less likely to be skipped. No single value is universally good; at first, a reasonable rule of thumb is to increase the rate when training feels slow.
  • max_iter [default 200] (number of training iterations):

    • With stochastic gradient descent or adam, the parameters are updated repeatedly using derivatives. This parameter sets the number of updates.
  • random_state [default None] (randomness):

    • NN training depends on random numbers in several places, so the result changes every time you run it. To avoid this, fix the seed of the random number generator. Passing an arbitrary integer, e.g. random_state=1, fixes the generator (a different integer gives a different, but again fixed, random sequence). When doing research, another professor may well ask you to "just reproduce that result", so always set the seed. In the rare case that you actually want a different result on every run, simply omit it.
  • tol [default $10^{-4}$] (criterion for stopping training early):

    • An NN stops training when the improvement of the objective function drops below a threshold, even if the specified number of iterations has not been reached. This parameter sets that threshold. A small value such as pow(10, -9) (i.e. $10^{-9}$) makes early stopping less likely; a large value makes training stop quickly.
  • verbose [default False] (flag for showing training progress):

    • If False, the training progress is not displayed; if True, it is. An NN is sometimes left training for a very long time, overnight or even for a month, and in such cases you will want True so that you can watch the progress.
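
To make the update rule in the learning_rate_init description concrete, here is a minimal, self-contained sketch of gradient descent with an L2 penalty on a toy one-parameter loss. The loss function, step count, and all values are made-up illustrations, not anything taken from scikit-learn.

# Toy example: minimize f(w) = (w - 3)^2 + alpha * w^2 by gradient descent
alpha = 0.01   # weight of the L2 penalty (cf. the alpha parameter above)
eta = 0.1      # learning rate (cf. learning_rate_init)
w = 10.0       # initial parameter value

for step in range(100):
    grad = 2 * (w - 3) + 2 * alpha * w  # derivative of the loss plus the L2 term
    w = w - eta * grad                  # update: subtract the scaled derivative

print(w)  # ends up close to 3, pulled slightly toward 0 by the L2 term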

Now let's start training. We prepare three activation functions, linear, sigmoid, and ReLU, and compare the differences. All other parameters are identical: the optimizer is adam, the network has two hidden layers (4 layers counting the input and output layers), 1000 training iterations, and a learning rate of 0.01. The random seed is fixed so the results can be reproduced.

In [5]:
from sklearn.neural_network import MLPClassifier

# *** Initial model setup ***

# Activation function: linear
NNmodel_Lin = MLPClassifier(hidden_layer_sizes=(10, 10, ), solver="adam", activation="identity", learning_rate_init=0.01, 
                             random_state=1, max_iter=1000, alpha=0.01, tol = pow(10, -5), verbose=True)
# Activation function: sigmoid
NNmodel_Sigm = MLPClassifier(hidden_layer_sizes=(10, 10, ), solver="adam", activation="logistic", learning_rate_init=0.01, 
                             random_state=1, max_iter=1000, alpha=0.01, tol = pow(10, -5), verbose=True)
# Activation function: ReLU
NNmodel_ReLu = MLPClassifier(hidden_layer_sizes=(10, 10, ), solver="adam", activation="relu", learning_rate_init=0.01, 
                            random_state=1, max_iter=1000, alpha=0.01, tol = pow(10, -5), verbose=True)

# Give the features and ground-truth labels, and start training
NNmodel_Lin.fit(X, Y)
NNmodel_Sigm.fit(X, Y)
NNmodel_ReLu.fit(X, Y)
Iteration 1, loss = 2.82639601
Iteration 2, loss = 1.85889591
Iteration 3, loss = 1.33939821
...
Iteration 355, loss = 0.02683528
Iteration 356, loss = 0.02870106
Training loss did not improve more than tol=0.000010 for 10 consecutive epochs. Stopping.
Iteration 1, loss = 1.10536716
Iteration 2, loss = 1.09911887
Iteration 3, loss = 1.09631957
...
Iteration 451, loss = 0.02735258
Iteration 452, loss = 0.02770746
Training loss did not improve more than tol=0.000010 for 10 consecutive epochs. Stopping.
Iteration 1, loss = 1.48349301
Iteration 2, loss = 1.29059958
Iteration 3, loss = 1.13979825
...
Iteration 163, loss = 0.02802389
Iteration 164, loss = 0.02767712
Training loss did not improve more than tol=0.000010 for 10 consecutive epochs. Stopping.
Out[5]:
MLPClassifier(activation='relu', alpha=0.01, batch_size='auto', beta_1=0.9,
       beta_2=0.999, early_stopping=False, epsilon=1e-08,
       hidden_layer_sizes=(10, 10), learning_rate='constant',
       learning_rate_init=0.01, max_iter=1000, momentum=0.9,
       n_iter_no_change=10, nesterovs_momentum=True, power_t=0.5,
       random_state=1, shuffle=True, solver='adam', tol=1e-05,
       validation_fraction=0.1, verbose=True, warm_start=False)

Although the number of iterations was set to 1000, training was stopped early because the loss stopped improving. If you do not want it to stop early, change tol.
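
Instead of reading through the verbose log, the recorded loss values can also be plotted; when the solver is sgd or adam, the fitted model stores them in its loss_curve_ attribute. A short sketch (not part of the original notebook):

import matplotlib.pyplot as plt

# Plot the training loss per iteration for the three models
plt.figure(figsize=(5, 4))
plt.plot(NNmodel_Lin.loss_curve_, label="identity")
plt.plot(NNmodel_Sigm.loss_curve_, label="logistic")
plt.plot(NNmodel_ReLu.loss_curve_, label="relu")
plt.xlabel("Iteration")
plt.ylabel("Training loss")
plt.legend()
plt.show()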

Prediction with the trained model

As with SVMs and decision trees, you can obtain the classification result from an NN with predict. With predict_proba you can obtain the values of the NN's output neurons (the class probabilities).

In [6]:
Res_Class = NNmodel_Lin.predict(X)
Res_OutputValue = NNmodel_Lin.predict_proba(X)

print("Predicted class for sample 0: ", Res_Class[0])
print("Output values for sample 0: ", Res_OutputValue[0])
Predicted class for sample 0:  Good
Output values for sample 0:  [3.08477817e-12 9.90281162e-01 9.71883797e-03]

We can see that when row 0 of X is given as input, it is predicted to be Good. The output neurons represent the confidence for Bad, Good, and Middle (the classes are sorted alphabetically). The second output neuron has the largest value, so the sample was classified as Good. To predict new data instead of predicting all of X at once, write as shown in the next cell (after the quick check below).
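
The alphabetical ordering mentioned above can be confirmed directly from the classes_ attribute of the fitted model (a small check, not part of the original notebook):

print(NNmodel_Lin.classes_)  # ['Bad' 'Good' 'Middle'] -- sorted alphabetically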

In [7]:
# What is the prediction for 5 hours of studying and an understanding level of 4?
print(NNmodel_Lin.predict([[5, 4]]))
print(NNmodel_Sigm.predict([[5, 4]]))
print(NNmodel_ReLu.predict([[5, 4]]))
['Middle']
['Middle']
['Middle']

So for someone who studies 5 hours and has an understanding level of 4, every model, regardless of activation function, predicts that their grade will probably be around the middle.
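
If you also want to see how confident each model is, not just the predicted label, predict_proba can be called on the same point (a small supplementary check, not part of the original notebook):

# Class probabilities (output-neuron values) for the same new data point
print(NNmodel_Lin.predict_proba([[5, 4]]))
print(NNmodel_Sigm.predict_proba([[5, 4]]))
print(NNmodel_ReLu.predict_proba([[5, 4]]))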

Visualizing the feature space

Now, as usual, let's visualize the feature space. The explanation was given in the decision-tree chapter, so it is omitted here.

In [9]:
# Generate mesh data
Xmin, Ymin, Xmax, Ymax = 0, 0, 10, 10 # min and max of the space
resolution = 0.1 # mesh fineness
x_mesh, y_mesh = np.meshgrid(np.arange(Xmin, Xmax, resolution),
                             np.arange(Ymin, Ymax, resolution))
MeshDat = np.array([x_mesh.ravel(), y_mesh.ravel()]).T

# Predict on the mesh data
z_Lin = NNmodel_Lin.predict(MeshDat) # linear
z_Sigm = NNmodel_Sigm.predict(MeshDat) # sigmoid
z_ReLu = NNmodel_ReLu.predict(MeshDat) # ReLU

# Reshape the results to the mesh shape
z_Lin = np.reshape(z_Lin, (len(x_mesh), len(y_mesh)))
z_Sigm = np.reshape(z_Sigm, (len(x_mesh), len(y_mesh)))
z_ReLu = np.reshape(z_ReLu, (len(x_mesh), len(y_mesh)))

# Visualization
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(12, 4))

# *** Lin Model ***
plt.subplot(1,3,1)

# NN output
plt.scatter(x_mesh[z_Lin=='Bad'], y_mesh[z_Lin=='Bad'], s=5, alpha=0.3, c='red')
plt.scatter(x_mesh[z_Lin=='Middle'], y_mesh[z_Lin=='Middle'], s=5, alpha=0.3, c='yellow')
plt.scatter(x_mesh[z_Lin=='Good'], y_mesh[z_Lin=='Good'], s=5, alpha=0.3, c='blue')
# Training data (Bad, Middle, Good)
plt.scatter(Dat03[:,0], Dat03[:,1], s=50, c='red', marker='s', 
            alpha=0.8, linewidths=0.5, edgecolors='black')
plt.scatter(Dat02[:,0], Dat02[:,1], s=50, c='orange', marker='s', 
            alpha=0.8, linewidths=0.5, edgecolors='black')
plt.scatter(Dat01[:,0], Dat01[:,1], s=50, c='blue', marker='s', 
            alpha=0.8, linewidths=0.5, edgecolors='black')
plt.title("Feature space by NN Lin Model")
plt.xlabel("X1: Studying Time")
plt.ylabel("X2: Understanding")
plt.grid(True)
plt.xlim(0, 10)
plt.ylim(0, 10)

# *** Sigm Model ***
plt.subplot(1,3,2)

# NN output
plt.scatter(x_mesh[z_Sigm=='Bad'], y_mesh[z_Sigm=='Bad'], s=5, alpha=0.3, c='red')
plt.scatter(x_mesh[z_Sigm=='Middle'], y_mesh[z_Sigm=='Middle'], s=5, alpha=0.3, c='yellow')
plt.scatter(x_mesh[z_Sigm=='Good'], y_mesh[z_Sigm=='Good'], s=5, alpha=0.3, c='blue')
# Training data (Bad, Middle, Good)
plt.scatter(Dat03[:,0], Dat03[:,1], s=50, c='red', marker='s', 
            alpha=0.8, linewidths=0.5, edgecolors='black')
plt.scatter(Dat02[:,0], Dat02[:,1], s=50, c='orange', marker='s', 
            alpha=0.8, linewidths=0.5, edgecolors='black')
plt.scatter(Dat01[:,0], Dat01[:,1], s=50, c='blue', marker='s', 
            alpha=0.8, linewidths=0.5, edgecolors='black')
plt.title("Feature space by NN Sigm Model")
plt.xlabel("X1: Studying Time")
plt.ylabel("X2: Understanding")
plt.grid(True)
plt.xlim(0, 10)
plt.ylim(0, 10)

# *** ReLu Model ***
plt.subplot(1,3,3)

# NN output
plt.scatter(x_mesh[z_ReLu=='Bad'], y_mesh[z_ReLu=='Bad'], s=5, alpha=0.3, c='red')
plt.scatter(x_mesh[z_ReLu=='Middle'], y_mesh[z_ReLu=='Middle'], s=5, alpha=0.3, c='yellow')
plt.scatter(x_mesh[z_ReLu=='Good'], y_mesh[z_ReLu=='Good'], s=5, alpha=0.3, c='blue')
# Training data (Bad, Middle, Good)
plt.scatter(Dat03[:,0], Dat03[:,1], s=50, c='red', marker='s', 
            alpha=0.8, linewidths=0.5, edgecolors='black')
plt.scatter(Dat02[:,0], Dat02[:,1], s=50, c='orange', marker='s', 
            alpha=0.8, linewidths=0.5, edgecolors='black')
plt.scatter(Dat01[:,0], Dat01[:,1], s=50, c='blue', marker='s', 
            alpha=0.8, linewidths=0.5, edgecolors='black')
plt.title("Feature space by NN ReLu Model")
plt.xlabel("X1: Studying Time")
plt.ylabel("X2: Understanding")
plt.grid(True)
plt.xlim(0, 10)
plt.ylim(0, 10)

plt.legend(["Bad", "Middle", "Good",
            "Bad (TrainData)", "Middle (TrainData)", "Good (TrainData)"], 
           loc="upper right", bbox_to_anchor=(1.5, 1))
Out[9]:
<matplotlib.legend.Legend at 0x1114b8470>

All three NNs, regardless of activation function, produce fairly similar outputs. In every case we obtain a space that predicts "study a lot with high understanding and the grade is good", "study a lot but with low understanding and the grade is about middling", and "do not study and the grade is low almost regardless of understanding". On the far left, the linear model's decision boundaries are, as explained earlier, simple straight lines. In the middle, the sigmoid model's boundaries are curved; this difference is the same as the difference between kernel functions in an SVM (linear kernel vs. RBF kernel). Finally, on the far right, the ReLU model builds its boundaries out of many small straight segments.

In this example we built a 4-layer feed-forward NN with a 2-dimensional input layer (studying time / understanding), two 10-neuron hidden layers, and a 3-dimensional output layer (Bad/Middle/Good). Because the input is 2-dimensional we could visualize it, but in practice the input dimensionality is usually much larger and the problems are more complex. Unlike this simple example, the chosen hyperparameters then strongly shape the character of the resulting model, so learn to aim for good models.

If you want a deeper understanding, try changing the hyperparameters described above and examine how the feature space changes. For example, what happens if you reduce the number of training iterations?
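
If you would rather search hyperparameters systematically than purely by trial and error, scikit-learn's GridSearchCV can wrap MLPClassifier. The parameter grid below is only an illustrative assumption, not a recommendation from this text:

from sklearn.model_selection import GridSearchCV
from sklearn.neural_network import MLPClassifier

# Candidate hyperparameter values to try (hypothetical choices)
param_grid = {
    "hidden_layer_sizes": [(10,), (10, 10)],
    "activation": ["logistic", "relu"],
    "learning_rate_init": [0.01, 0.001],
}

search = GridSearchCV(MLPClassifier(solver="adam", max_iter=1000, random_state=1),
                      param_grid, cv=3)
search.fit(X, Y)  # X, Y: the training dataset built earlier
print(search.best_params_)
print(search.best_score_)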

Regression

Like SVMs and decision trees, a neural network can also solve regression problems. For regression, use MLPRegressor. Below we build a 5-layer feed-forward network with three hidden layers, using the sigmoid activation function. To make the differences visible, three settings of the number of training iterations are prepared: 100, 500, and 1000. The regression target is a sin function, chosen for clarity.

In [10]:
from sklearn.neural_network import MLPRegressor

# Training data for a model that estimates y from x
x=np.zeros([50,1])
for i in range(0,50):
    x[i,0]=i*0.2
y = np.sin(x).ravel() # flatten the array to 1 dimension

# Build the models

# 100 training iterations
Reg_Sigm_100 = MLPRegressor(activation='logistic', max_iter=100, tol = pow(10, -20), hidden_layer_sizes=(100,100,100,),
                        learning_rate_init=0.001, random_state=1)

# 500 training iterations
Reg_Sigm_500 = MLPRegressor(activation='logistic', max_iter=500, tol = pow(10, -20), hidden_layer_sizes=(100,100,100,),
                        learning_rate_init=0.001, random_state=1)

# 1000 training iterations
Reg_Sigm_1000 = MLPRegressor(activation='logistic', max_iter=1000, tol = pow(10, -20), hidden_layer_sizes=(100,100,100,),
                        learning_rate_init=0.001, random_state=1)

# Training
Reg_Sigm_100.fit(x, y)
Reg_Sigm_500.fit(x, y)
Reg_Sigm_1000.fit(x, y)

# Prediction
y_Sigm_100 = Reg_Sigm_100.predict(x)
y_Sigm_500 = Reg_Sigm_500.predict(x)
y_Sigm_1000 = Reg_Sigm_1000.predict(x)

# Visualization
plt.scatter(x, y, c='black', label='Train data')
plt.plot(x, y_Sigm_100, c='r', label='Number of Learning: 100')
plt.plot(x, y_Sigm_500, c='b', label='Number of Learning: 500')
plt.plot(x, y_Sigm_1000, c='g', label='Number of Learning: 1000')

plt.xlabel('x')
plt.ylabel('y')
plt.title('NN Regression')
plt.legend(loc="upper right", bbox_to_anchor=(1.6, 1))
plt.show()
/usr/local/lib/python3.7/site-packages/sklearn/neural_network/multilayer_perceptron.py:562: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (100) reached and the optimization hasn't converged yet.
  % self.max_iter, ConvergenceWarning)
/usr/local/lib/python3.7/site-packages/sklearn/neural_network/multilayer_perceptron.py:562: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (500) reached and the optimization hasn't converged yet.
  % self.max_iter, ConvergenceWarning)

Two warnings appeared, saying "Maximum iterations (100, 500) reached and the optimization hasn't converged yet", i.e., training has not finished. Looking at the results, with few iterations the model cannot reproduce the sin function at all, while increasing the number of iterations reproduces it correctly. In this way an NN can also solve regression problems that estimate real values. The input was kept 1-dimensional here for clarity, but multi-dimensional inputs can be handled in exactly the same way.
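
For reference, here is a minimal sketch of the multi-dimensional case, regressing a toy target y = sin(x1) + cos(x2) from a 2-dimensional input. The data and settings are assumptions for illustration, not part of the original notebook:

import numpy as np
from sklearn.neural_network import MLPRegressor

rng = np.random.RandomState(1)
X2 = rng.uniform(0.0, 2.0 * np.pi, size=(200, 2))  # 2-dimensional input
y2 = np.sin(X2[:, 0]) + np.cos(X2[:, 1])           # target values

reg = MLPRegressor(activation='logistic', hidden_layer_sizes=(100, 100, 100),
                   learning_rate_init=0.001, max_iter=2000, random_state=1)
reg.fit(X2, y2)

print(reg.predict([[1.0, 2.0]]))  # should be roughly np.sin(1.0) + np.cos(2.0)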