2019年5月30日木曜日

Python - 機械学習の基礎 - 強化学習 - Q学習

コード

Python 3

```python
#!/usr/bin/env python3
import random
import matplotlib.pyplot as plt

# Experiment settings.
gen_max = 1000  # number of training episodes
node = 15  # nodes in the complete binary tree (depth 3: 1 + 2 + 4 + 8)
alpha = 0.1  # learning rate
gamma = 0.9  # discount factor
epsilon = 0.3  # probability of taking a random action (epsilon-greedy)
leaf_rewards = {14: 1000, 11: 500}  # leaf node -> reward paid on arrival


def choose_child(qvalue, s, epsilon=epsilon, rng=random):
    """Pick the next node below node ``s`` with an epsilon-greedy policy.

    Node ``s`` has children ``2*s + 1`` (left) and ``2*s + 2`` (right).
    With probability ``epsilon`` a child is chosen uniformly at random;
    otherwise the child with the larger Q value is taken.

    Args:
        qvalue: Q table indexed by node number.
        s: current, non-leaf node.
        epsilon: exploration probability.
        rng: object providing ``random()`` and ``randrange()`` -- the
            ``random`` module itself or a seeded ``random.Random``.

    Returns:
        The index of the chosen child.
    """
    left, right = 2 * s + 1, 2 * s + 2
    if rng.random() < epsilon:
        return left if rng.randrange(2) == 0 else right
    # Strict comparison: on a tie the right child wins.
    return left if qvalue[left] > qvalue[right] else right


def update_q(qvalue, s, alpha=alpha, gamma=gamma, rewards=None):
    """Apply one Q-learning update, in place, to the node ``s`` just entered.

    * Rewarded leaf: ``qvalue[s]`` moves toward its reward.
    * Internal node: ``qvalue[s]`` moves toward ``gamma`` times the larger
      Q value of its two children.
    * Any other leaf pays no reward and is left unchanged.

    Args:
        qvalue: Q table indexed by node number (modified in place).
        s: node that was just entered.
        alpha: learning rate.
        gamma: discount factor.
        rewards: mapping leaf node -> reward; defaults to ``leaf_rewards``.
    """
    rewards = leaf_rewards if rewards is None else rewards
    if s in rewards:
        qvalue[s] += alpha * (rewards[s] - qvalue[s])
    elif 2 * s + 2 < len(qvalue):  # s has children, so it is internal
        qmax = max(qvalue[2 * s + 1], qvalue[2 * s + 2])
        qvalue[s] += alpha * (gamma * qmax - qvalue[s])


def train(gen_max=gen_max, node=node, alpha=alpha, gamma=gamma,
          epsilon=epsilon, rewards=None, rng=random, report_every=100):
    """Learn a Q table for the binary-tree maze.

    Every episode starts at the root (node 0) and walks down to a leaf,
    updating the Q value of each node it enters.

    Args:
        gen_max: number of episodes.
        node: number of nodes; must be ``2**k - 1`` (a complete binary tree).
        alpha: learning rate.
        gamma: discount factor.
        epsilon: exploration probability.
        rewards: mapping leaf node -> reward; defaults to ``leaf_rewards``.
        rng: random source (see ``choose_child``); initial Q values are
            drawn with ``rng.randrange(101)``.
        report_every: after every ``report_every``-th episode (starting
            with the first) print Q1..Q{node-1} truncated to int; 0 disables.

    Returns:
        ``(qvalue, history)``: the final Q table, and ``history[j]`` holding
        the value of Q{j+1} after each episode.

    Raises:
        ValueError: if ``node`` is not of the form ``2**k - 1``.
    """
    if node < 1 or node & (node + 1):
        raise ValueError(f'node must be 2**k - 1, got {node}')
    depth = node.bit_length() - 1  # moves needed to go from root to a leaf
    qvalue = [rng.randrange(101) for _ in range(node)]
    history = [[] for _ in range(node - 1)]
    for episode in range(gen_max):
        s = 0
        for _ in range(depth):
            s = choose_child(qvalue, s, epsilon, rng)
            update_q(qvalue, s, alpha, gamma, rewards)
        if report_every and episode % report_every == 0:
            print(' '.join(str(int(q)) for q in qvalue[1:]))
        # The root's value is never updated, so only Q1.. are recorded.
        for series, q in zip(history, qvalue[1:]):
            series.append(q)
    return qvalue, history


def plot(history, path='sample2.png'):
    """Plot each Q series against the episode number and save it to ``path``.

    Each line gets its own legend label Q1, Q2, ... (one per series).
    """
    for j, qs in enumerate(history, start=1):
        plt.plot(range(1, len(qs) + 1), qs, label=f'Q{j}')
    plt.legend()
    plt.savefig(path)


def main():
    """Train with the module-level settings and save the learning curves."""
    _, history = train()
    plot(history)


if __name__ == '__main__':
    main()
```

```
C:\Users\...>py qlearning.py
46 41 41 47 96 35 30 7 93 74 6 39 27 100
75 149 31 83 242 73 30 7 93 74 386 39 27 343
75 458 31 83 448 638 30 7 93 74 499 39 27 901
75 807 30 83 449 898 30 7 93 74 499 39 27 999
75 809 29 83 449 899 30 7 93 74 499 39 27 999
75 809 29 83 449 899 30 7 93 74 499 39 27 999
75 809 29 83 449 899 30 7 93 74 499 39 27 999
75 809 29 83 449 899 30 7 93 74 499 39 27 999
75 809 28 83 449 899 30 7 93 74 499 39 27 999
75 809 28 83 449 899 30 7 93 74 499 39 27 999

C:\Users\...>
```