-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathChapter10.2.py
More file actions
52 lines (47 loc) · 1.92 KB
/
Copy pathChapter10.2.py
File metadata and controls
52 lines (47 loc) · 1.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import math
import numpy as np
import matplotlib.pyplot as plt
# --- Toy data: c classes of n // c one-dimensional samples each ---------------
n = 90  # total number of training samples
c = 3   # number of classes
if n % c:
    # The per-class reshapes below need equal class sizes; fail with a clear
    # message instead of an opaque reshape error.
    raise ValueError(f"n ({n}) must be divisible by c ({c})")
# Labels 0, 1, ..., c-1, each repeated n // c times (class-major order),
# stored as a float column vector of shape (n, 1).
y = np.repeat(np.arange(c, dtype=float), n // c).reshape(n, 1)
# Class i is drawn from N(mu_i, 1) with the means mu_i evenly spaced on
# [-3, 3]. Transposing before flattening puts the samples in the same
# class-major order as y.
x = np.random.randn(n // c, c) + np.linspace(-3, 3, c)
x = x.T.reshape(n, 1)

# --- Kernel least-squares class-posterior estimation --------------------------
hh = 2 * 1 ** 2  # 2 * h**2 for Gaussian kernel bandwidth h = 1
x2 = x ** 2
# NOTE(review): despite its name this is not a step size; it is the ridge
# (L2) regularization weight added to the diagonal in the solve below.
learning_rate = 0.1
N = 100  # number of test points
X = np.linspace(-5, 5, N).reshape(-1, 1)  # test grid, shape (N, 1)
# Gaussian kernel via (a - b)^2 = a^2 + b^2 - 2ab, using broadcasting:
# k[a, b] between training points, shape (n, n);
# K[t, b] between test and training points, shape (N, n).
k = np.exp(-(x2 + x2.T - 2 * x.dot(x.T)) / hh)
K = np.exp(-(X ** 2 + x2.T - 2 * X.dot(x.T)) / hh)
Kt = np.empty((N, c))  # one column per class (was hard-coded to 3)
for i in range(c):
    yk = (y == i)  # (n, 1) boolean indicator of class i
    idx = np.where(yk)[0]  # training indices belonging to class i
    ky = k[:, idx]  # kernel basis centred only on class-i samples
    # Regularized least-squares fit of the class indicator:
    #   ty = (ky^T ky + lambda I)^-1 ky^T yk
    # The system matrix is symmetric positive definite for lambda > 0, so a
    # direct solve suffices.
    ty = np.linalg.solve(
        ky.T.dot(ky) + learning_rate * np.eye(idx.size),
        ky.T.dot(yk.astype(float)),
    )
    # Negative scores are clipped to zero before normalization.
    Kt[:, i] = np.maximum(0, K[:, idx].dot(ty)).reshape(-1)
# Normalize each row into a distribution over classes. A row whose clipped
# scores are all zero would give 0/0 = NaN; use a uniform distribution there.
totals = Kt.sum(axis=1, keepdims=True)
ph = np.divide(Kt, totals, out=np.full_like(Kt, 1.0 / c), where=totals > 0)
# Plot the estimated class-posterior curves on the test grid first, then the
# training samples of each class as markers just below zero.
for col, fmt, lbl in ((0, 'b-', 'q(y=1|x)'),
                      (1, 'r--', 'q(y=2|x)'),
                      (2, 'g:', 'q(y=3|x)')):
    plt.plot(X, ph[:, col], fmt, label=lbl)
per_class = int(n / c)
for cls, level, marker in ((0, -0.1, 'bo'),
                           (1, -0.2, 'rx'),
                           (2, -0.1, 'gv')):
    plt.plot(x[np.where(y == cls)[0]], level * np.ones((per_class, 1)), marker)
plt.axis([-5, 5, -0.3, 1.8])
plt.legend()
plt.show()
from sklearn.linear_model import LogisticRegression
# Baseline for comparison: L2-regularized logistic regression fit on the raw
# 1-D inputs, then evaluated as class probabilities on the same test grid X.
log_reg = LogisticRegression(
    penalty='l2',
    solver="lbfgs",
    random_state=42,
)
y_pred_pro = log_reg.fit(x, y.ravel()).predict_proba(X)
# Same figure layout for the logistic-regression baseline: the predicted
# probability columns (ordered like log_reg.classes_), then the training
# samples below zero.
curves = [('b-', 'q(y=1|x)'), ('r--', 'q(y=2|x)'), ('g:', 'q(y=3|x)')]
for col, (fmt, lbl) in enumerate(curves):
    plt.plot(X, y_pred_pro[:, col], fmt, label=lbl)
per_class = int(n / c)
markers = [(-0.1, 'bo'), (-0.2, 'rx'), (-0.1, 'gv')]
for cls, (level, marker) in enumerate(markers):
    plt.plot(x[np.where(y == cls)[0]], level * np.ones((per_class, 1)), marker)
plt.axis([-5, 5, -0.3, 1.8])
plt.legend()
plt.show()