Jun 15, 2020
Here it is:
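import numpy as np
import matplotlib.pyplot as plt

# Evaluate the model on a 20 x 20 grid over the unit square.
# Each row is [i, j, prediction]; `forward` and `model` are assumed to be defined already.
grid = np.linspace(0, 1, 20)
data = np.array([[i, j, forward(model, np.array([[i, j]]))[0]] for i in grid for j in grid])

# Scatter the grid points, colored by the model output.
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(data[:, 0], data[:, 1], c=data[:, 2])
plt.show()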