Skip to content

Commit dd925a7

Browse files
committed
difficult
1 parent b41fbbf commit dd925a7

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

Summer20/NeuralNetwork/bigfunc.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from tensorflow.keras import datasets, models, layers, losses
2+
import tensorflow as tf
3+
import numpy as np
4+
import random
5+
import math
6+
import matplotlib.pyplot as plt
7+
8+
def generatePoints(number, bounds):
9+
inputList = []
10+
outputList = []
11+
while(number>0):
12+
value = random.uniform(bounds[0],bounds[1])
13+
inputList.append([value])
14+
outputList.append([math.sin(value**2)])
15+
number = number - 1
16+
return inputList, outputList
17+
18+
bounds = (-5,5) # represents full system dynamics
19+
inputList, outputList = generatePoints(30000, bounds)
20+
21+
# neural network code
22+
model = models.Sequential()
23+
model.add(layers.Dense(32, activation='tanh', input_shape=(1,)))
24+
model.add(layers.Dense(48, activation='tanh'))
25+
model.add(layers.Dense(1, activation=None))
26+
model.compile(optimizer='Adam',
27+
loss=losses.MeanSquaredError(),
28+
metrics=['mean_squared_error'])
29+
30+
history = model.fit(np.array(inputList),np.array(outputList), epochs=300)
31+
#print(model.get_weights())
32+
33+
# plots out learning curve
34+
plt.plot(history.history['mean_squared_error'], label='mean_squared_error')
35+
plt.xlabel('Epoch')
36+
plt.ylabel('MSE')
37+
plt.ylim([0.0, 0.5])
38+
plt.legend(loc='lower right')
39+
plt.show()
40+
41+
42+
bounds = (-7,7)
43+
# generate test data
44+
inputTest, outputTest = generatePoints(20, bounds)
45+
print(model.predict(np.array(inputTest)))
46+
print(outputTest)
47+
48+
graph = plt.figure()
49+
ax = graph.add_subplot(111)
50+
51+
x = np.linspace(-8,8,500)
52+
y = np.sin(x**2)
53+
54+
plt.plot(x,y, label= 'y = sin(x^2)', markersize = 2, c='c')
55+
ax.scatter(inputList, outputList, label = 'training', c='b')
56+
ax.scatter(inputTest,model.predict(np.array(inputTest)), label = 'testing', c='r')
57+
plt.legend(loc='lower right')
58+
plt.show()

0 commit comments

Comments
 (0)