inch-blog

Welcome to Inch-blog ! Home is a place where you can read mainly technical articles. LIFE is mainly about my personal life.

Pytorch の Cross Entropy Loss について解説

解説

Pytorch の Cross Entropy Loss について具体的にどういう計算をしているかをコードとともに示す.

交差エントロピー誤差を式で表現すると次のようになります

H(p,q)=xp(x)log(q(x))H(p, q) = - \sum_{x} p(x) \log{(q(x))}

なお、この log\log は自然対数(底が ee )です。

pp は真の確率分布、qq は推定した確率分布です。

ppqq の確率分布が似ていると交差エントロピー誤差は小さくなり、似ていないと交差エントロピー誤差は大きくなります。

式だけだとわかりにくいので、わかりやすいように例をあげます。今、写真に映っている果物が、

  • バナナ
  • りんご
  • みかん

のどれかを予測するとします。

写真に映っている果物がバナナの場合、真の確率分布 pp(1, 0, 0)となります。

true = torch.tensor([[1, 0, 0]], dtype=torch.float)
# true = torch.tensor([0]) # ラベルのインデックスのみを指定する場合はこっち

# true
# tensor([[1., 0., 0.]])

そして、推定した確率分布 qq(0.8, 0.1, 0.1)だったとします。

pred_list = [0.8, 0.1, 0.1]
pred = torch.tensor([[np.log(p) for p in pred_list]], requires_grad=True)

# pred
# tensor([[-0.2231, -2.3026, -2.3026]], dtype=torch.float64, requires_grad=True)

この場合、交差エントロピー誤差は次のようになります。

H(p,q)=(1×loge0.8+0×loge0.1+0×loge0.1)=loge0.8=(0.2231)=0.2231\begin{aligned} H(p, q) &= - ( 1 \times \log_{e}{0.8} + 0 \times \log_{e}{0.1} + 0 \times \log_{e}{0.1} ) \\ &= - \log_{e}{0.8} \\ &= - (-0.2231) \\ &= 0.2231 \end{aligned}

loss = criterion(pred, true)

# loss
# tensor(0.2231, dtype=torch.float64, grad_fn=<NegBackward0>)

一方、推定した確率分布 q(x)q(x)(0.3, 0.4, 0.3)だったとします。

pred_list = [0.3, 0.4, 0.3]
pred = torch.tensor([[np.log(p) for p in pred_list]], requires_grad=True)

# pred
# tensor([[-1.2040, -0.9163, -1.2040]], dtype=torch.float64, requires_grad=True)

この場合、交差エントロピー誤差は次のようになります。

H(p,q)=(1×loge0.3+0×loge0.4+0×loge0.4)=loge0.3=(1.204)=1.204\begin{aligned} H(p, q) &= - ( 1 \times \log_{e}{0.3} + 0 \times \log_{e}{0.4} + 0 \times \log_{e}{0.4} ) \\ &= - \log_{e}{0.3} \\ &= - (-1.204) \\ &= 1.204 \end{aligned}

loss = criterion(pred, true)

# loss
# tensor(1.2040, dtype=torch.float64, grad_fn=<NegBackward0>)

交差エントロピー誤差は後者よりも前者の方が小さくなっており、真の確率分布 pp に近いのは前者の確率分布 qq であることが分かります。

直感とも一致するので、交差エントロピー誤差が損失関数として適していることが分かります。

ソースコード全文

サンプル 1
import torch
import torch.nn as nn
import numpy as np

criterion = nn.CrossEntropyLoss(reduction="sum")
true = torch.tensor([[1, 0, 0]], dtype=torch.float)
# true = torch.tensor([0])
pred_list = [0.8, 0.1, 0.1]
pred = torch.tensor([[np.log(p) for p in pred_list]], requires_grad=True)
loss = criterion(pred, true)

print(true)
print(pred)
print(loss)

# tensor([[1., 0., 0.]])
# tensor([[-0.2231, -2.3026, -2.3026]], dtype=torch.float64, requires_grad=True)
# tensor(0.2231, dtype=torch.float64, grad_fn=<NegBackward0>)
サンプル 2
criterion = nn.CrossEntropyLoss(reduction="sum")
true = torch.tensor([[1, 0, 0]], dtype=torch.float)
# true = torch.tensor([0])
pred_list = [0.3, 0.4, 0.3]
pred = torch.tensor([[np.log(p) for p in pred_list]], requires_grad=True)
loss = criterion(pred, true)

print(true)
print(pred)
print(loss)

# tensor([[1., 0., 0.]])
# tensor([[-1.2040, -0.9163, -1.2040]], dtype=torch.float64, requires_grad=True)
# tensor(1.2040, dtype=torch.float64, grad_fn=<NegBackward0>)