解説
Pytorch の Cross Entropy Loss について具体的にどういう計算をしているかをコードとともに示す.
交差エントロピー誤差を式で表現すると次のようになります
なお、この は自然対数(底が )です。
は真の確率分布、 は推定した確率分布です。
と の確率分布が似ていると交差エントロピー誤差は小さくなり、似ていないと交差エントロピー誤差は大きくなります。
式だけだとわかりにくいので、わかりやすいように例をあげます。今、写真に映っている果物が、
- バナナ
- りんご
- みかん
のどれかを予測するとします。
写真に映っている果物がバナナの場合、真の確率分布 は(1, 0, 0)
となります。
true = torch.tensor([[1, 0, 0]], dtype=torch.float)
# true = torch.tensor([0]) # ラベルのインデックスのみを指定する場合はこっち
# true
# tensor([[1., 0., 0.]])
そして、推定した確率分布 が(0.8, 0.1, 0.1)
だったとします。
pred_list = [0.8, 0.1, 0.1]
pred = torch.tensor([[np.log(p) for p in pred_list]], requires_grad=True)
# pred
# tensor([[-0.2231, -2.3026, -2.3026]], dtype=torch.float64, requires_grad=True)
この場合、交差エントロピー誤差は次のようになります。
loss = criterion(pred, true)
# loss
# tensor(0.2231, dtype=torch.float64, grad_fn=<NegBackward0>)
一方、推定した確率分布 が(0.3, 0.4, 0.3)
だったとします。
pred_list = [0.3, 0.4, 0.3]
pred = torch.tensor([[np.log(p) for p in pred_list]], requires_grad=True)
# pred
# tensor([[-1.2040, -0.9163, -1.2040]], dtype=torch.float64, requires_grad=True)
この場合、交差エントロピー誤差は次のようになります。
loss = criterion(pred, true)
# loss
# tensor(1.2040, dtype=torch.float64, grad_fn=<NegBackward0>)
交差エントロピー誤差は後者よりも前者の方が小さくなっており、真の確率分布 に近いのは前者の確率分布 であることが分かります。
直感とも一致するので、交差エントロピー誤差が損失関数として適していることが分かります。
ソースコード全文
サンプル 1
import torch
import torch.nn as nn
import numpy as np
criterion = nn.CrossEntropyLoss(reduction="sum")
true = torch.tensor([[1, 0, 0]], dtype=torch.float)
# true = torch.tensor([0])
pred_list = [0.8, 0.1, 0.1]
pred = torch.tensor([[np.log(p) for p in pred_list]], requires_grad=True)
loss = criterion(pred, true)
print(true)
print(pred)
print(loss)
# tensor([[1., 0., 0.]])
# tensor([[-0.2231, -2.3026, -2.3026]], dtype=torch.float64, requires_grad=True)
# tensor(0.2231, dtype=torch.float64, grad_fn=<NegBackward0>)
サンプル 2
criterion = nn.CrossEntropyLoss(reduction="sum")
true = torch.tensor([[1, 0, 0]], dtype=torch.float)
# true = torch.tensor([0])
pred_list = [0.3, 0.4, 0.3]
pred = torch.tensor([[np.log(p) for p in pred_list]], requires_grad=True)
loss = criterion(pred, true)
print(true)
print(pred)
print(loss)
# tensor([[1., 0., 0.]])
# tensor([[-1.2040, -0.9163, -1.2040]], dtype=torch.float64, requires_grad=True)
# tensor(1.2040, dtype=torch.float64, grad_fn=<NegBackward0>)