Skip to content

Commit 701c772

Browse files
committed
3d graph
1 parent dd925a7 commit 701c772

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

Summer20/NeuralNetwork/tf3d.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from tensorflow.keras import datasets, models, layers, losses
2+
import tensorflow as tf
3+
from mpl_toolkits import mplot3d
4+
import numpy as np
5+
import matplotlib.pyplot as plt
6+
import random
7+
8+
def cone(x,y):
9+
return np.sqrt(x**2 + y**2)
10+
11+
def ripple(x,y):
12+
return np.sin(10 * (x**2 + y**2)) / 10
13+
14+
def makeTuple(X,Y):
15+
inputList = []
16+
for index, value in enumerate(X):
17+
for index1, value1 in enumerate(value):
18+
inputList.append([value1, Y[index][index1]])
19+
return inputList
20+
21+
def unpackTuple(A):
22+
X = []
23+
Y = []
24+
for item in A:
25+
X.append([item[0]])
26+
Y.append([item[1]])
27+
return X, Y
28+
29+
def makeArray(Z):
30+
zList = []
31+
for subList in Z:
32+
for value in subList:
33+
zList.append([value])
34+
return zList
35+
36+
def randomPoints(number, bounds):
37+
inputList = []
38+
outputList = []
39+
while(number > 0):
40+
value1 = random.uniform(bounds[0],bounds[1])
41+
value2 = random.uniform(bounds[0],bounds[1])
42+
inputList.append([value1, value2])
43+
outputList.append([ripple(value1, value2)])
44+
number = number -1
45+
return inputList, outputList
46+
47+
bounds = (-1,1)
48+
inputList, outputList = randomPoints(50000, bounds)
49+
X_Train, Y_Train = unpackTuple(inputList)
50+
51+
model = models.Sequential()
52+
model.add(layers.Dense(32, activation='exponential', input_shape=(2,)))
53+
model.add(layers.Dense(48, activation='tanh'))
54+
model.add(layers.Dense(1, activation=None))
55+
model.compile(optimizer='Adam',
56+
loss=losses.MeanSquaredError(),
57+
metrics=['mean_squared_error'])
58+
59+
history = model.fit(np.array(inputList),np.array(outputList), epochs=300)
60+
#print(model.get_weights())
61+
62+
63+
# plots out learning curve
64+
# plt.plot(history.history['mean_squared_error'], label='mean_squared_error')
65+
# plt.xlabel('Epoch')
66+
# plt.ylabel('MSE')
67+
# plt.ylim([0.0, 0.2])
68+
# plt.legend(loc='lower right')
69+
# plt.show()
70+
71+
# generate test data
72+
inputTest, outputTest = randomPoints(10, bounds)
73+
X_Test, Y_Test = unpackTuple(inputTest)
74+
print(model.predict(np.array(inputTest)))
75+
print(outputTest)
76+
77+
x = np.linspace(-1, 1, 800)
78+
y = np.linspace(-1, 1, 800)
79+
80+
X, Y = np.meshgrid(x, y)
81+
Z = ripple(X, Y)
82+
83+
fig = plt.figure()
84+
ax = plt.axes(projection="3d")
85+
86+
ax.plot_wireframe(X, Y, Z, color='c')
87+
ax.scatter3D(X_Test, Y_Test, model.predict(np.array(inputTest)), c='r')
88+
ax.set_xlabel('x')
89+
ax.set_ylabel('y')
90+
ax.set_zlabel('z')
91+
92+
plt.show()

0 commit comments

Comments
 (0)