This repository has been archived by the owner on May 11, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrecurrent_neural_network.py
executable file
·74 lines (55 loc) · 2 KB
/
recurrent_neural_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import numpy as np
import matplotlib.pyplot as plt
import neurolab as nl
def get_data(num_points):
# Create sine waveforms
wave_1 = 0.5 * np.sin(np.arange(0, num_points))
wave_2 = 3.6 * np.sin(np.arange(0, num_points))
wave_3 = 1.1 * np.sin(np.arange(0, num_points))
wave_4 = 4.7 * np.sin(np.arange(0, num_points))
# Create varying amplitudes
amp_1 = np.ones(num_points)
amp_2 = 2.1 + np.zeros(num_points)
amp_3 = 3.2 * np.ones(num_points)
amp_4 = 0.8 + np.zeros(num_points)
wave = np.array([wave_1, wave_2, wave_3, wave_4]).reshape(num_points*4, 1)
amp = np.array([amp_1, amp_2, amp_3, amp_4]).reshape(num_points*4, 1)
return wave, amp
# Visualize the output
def visualize_output(nn, num_points_test):
wave, amp = get_data(num_points_test)
output = nn.sim(wave)
plt.plot(amp.reshape(num_points_test*4))
plt.plot(output.reshape(num_points_test*4))
if __name__ == '__main__':
# Create some sample data
num_points = 40
wave, amp = get_data(num_points)
# Create a recurrent neural network with 2 layers
nn = nl.net.newelm([[-2, 2]], [10, 1], [nl.trans.TanSig(), nl.trans.PureLin()])
# Set the init function for each layer
nn.layers[0].initf = nl.init.InitRand([-0.1, 0.1], 'wb')
nn.layers[1].initf = nl.init.InitRand([-0.1, 0.1], 'wb')
nn.init()
# Train the recurrent neural network
err = nn.train(wave, amp, epochs=1200, show=100, goal=0.01)
# Run the training data through the network
output = nn.sim(wave)
# Plot the results
plt.subplot(211)
plt.plot(err)
plt.xlabel('Number of epochs')
plt.ylabel('Error (MSE)')
plt.subplot(212)
plt.plot(amp.reshape(num_points*4))
plt.plot(output.reshape(num_points*4))
plt.legend(['Original', 'Predicted'])
# Testing the network performance on unknown data
plt.figure()
plt.subplot(211)
visualize_output(nn, 82)
plt.xlim([0, 300])
plt.subplot(212)
visualize_output(nn, 49)
plt.xlim([0, 300])
plt.show()