My Very First Model Trained

This Python script demonstrates the creation and training of a simple neural network for linear regression using TensorFlow and NumPy.
Run the script with a Python interpreter. On execution, the script asks whether to train the neural network. If the user answers `y`, the script fits the neural network to the given data. Otherwise, it prints a warning that the neural network is not trained.

```shell
python simple_linear_regression_nn.py
```
The neural network consists of one layer with one neuron, which makes it a simple linear regression model: it learns a single weight `w` and bias `b` so that its output is `w * x + b`. The model is compiled with Stochastic Gradient Descent (`sgd`) as the optimizer and Mean Squared Error (`mean_squared_error`) as the loss function.
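```python
import tensorflow as tf
# One Dense layer with a single unit computes y = w * x + b (linear regression).
model = tf.keras.Sequential([tf.keras.layers.Dense(units=1, input_shape=[1])])
model.compile(optimizer='sgd', loss='mean_squared_error')
```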
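To make the one-neuron model concrete, the sketch below spells out in plain Python (not TensorFlow) what the layer computes and how the Mean Squared Error loss scores it, using the same six training points the script defines. The helper names `dense_forward` and `mean_squared_error` and the trial weights are illustrative assumptions, not part of the script.

```python
def dense_forward(xs, w, b):
    """Output of a one-unit Dense layer with a single input: y = w * x + b."""
    return [w * x + b for x in xs]


def mean_squared_error(y_true, y_pred):
    """Average of the squared differences between targets and predictions."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


xs = [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0]
ys = [-2.0, 1.0, 4.0, 7.0, 10.0, 13.0]

# A poor guess (w=1, b=0) versus the line y = 3x + 1 that generated the data.
print(mean_squared_error(ys, dense_forward(xs, 1.0, 0.0)))  # 27.666...
print(mean_squared_error(ys, dense_forward(xs, 3.0, 1.0)))  # 0.0
```

Training is the search for the `w` and `b` that drive this loss toward zero.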
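The input data (`xs`) and corresponding output data (`ys`) are provided for training the neural network. The script uses NumPy to create arrays for the input and output data. Every pair follows the relationship `y = 3x + 1`, which the network has to discover from these six examples.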
```python
import numpy as np
# Six training examples; each y equals 3 * x + 1.
xs = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=float)
ys = np.array([-2.0, 1.0, 4.0, 7.0, 10.0, 13.0], dtype=float)
```
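Because the data are exactly linear, the weight and bias the network should approach can also be computed directly with ordinary least squares. The script does not do this (it learns the values by gradient descent); the pure-Python sketch below, with the hypothetical helper `least_squares_line`, only shows the target the training is heading toward.

```python
def least_squares_line(xs, ys):
    """Closed-form ordinary least squares fit of y = w * x + b."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Slope = covariance(x, y) / variance(x); the 1/n factors cancel.
    cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    w = cov_xy / var_x
    b = mean_y - w * mean_x
    return w, b


xs = [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0]
ys = [-2.0, 1.0, 4.0, 7.0, 10.0, 13.0]
w, b = least_squares_line(xs, ys)
print(w, b)  # 3.0 1.0
```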
The user is asked whether to train the neural network. If training is requested, the model is fitted to the input and output data for 500 epochs with the `model.fit` method.
```python
# Ask before training; any answer other than "y" skips model.fit.
user_request = input("Do you want to train the neural network? (y/n) ")
if user_request.strip().lower() == "y":
    model.fit(xs, ys, epochs=500)
    print("Training complete.")
else:
    print("Training skipped. (Beware: The neural network is not trained.)")
```
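What `model.fit` does for this model can be approximated by a short gradient-descent loop. The pure-Python sketch below assumes the Keras defaults as we understand them: the `sgd` optimizer with a learning rate of 0.01 and no momentum, and a default batch size of 32, so all six samples form one batch and each epoch is a single full-batch step. It starts from `w = 0` instead of Keras's random initial weight, so its numbers will not match a Keras run exactly; `train_linear` is a hypothetical helper.

```python
def train_linear(xs, ys, epochs=500, learning_rate=0.01):
    """Full-batch gradient descent on the mean squared error of y = w * x + b."""
    w, b = 0.0, 0.0  # Keras would start w at a random value and b at 0.
    n = len(xs)
    for _ in range(epochs):
        errors = [w * x + b - y for x, y in zip(xs, ys)]
        # Gradients of mean((w*x + b - y)^2) with respect to w and b.
        grad_w = 2.0 * sum(e * x for e, x in zip(errors, xs)) / n
        grad_b = 2.0 * sum(errors) / n
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    return w, b


xs = [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0]
ys = [-2.0, 1.0, 4.0, 7.0, 10.0, 13.0]
w, b = train_linear(xs, ys)
print(w, b, w * 10.0 + b)  # w near 3, b near 1, prediction near 31
```

After 500 steps the parameters are very close to, but not exactly, `w = 3` and `b = 1`, which is why a trained network predicts a value near 31 rather than exactly 31.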
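Whether or not training was performed, the script then uses the model to predict the output for the input value `10.0`. After training, the printed value should be close to 31 (3 * 10 + 1); if training was skipped, the model is still untrained and the prediction is essentially arbitrary.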
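```python
# Pass a NumPy array; newer Keras releases may reject a plain Python list here.
print(model.predict(np.array([10.0])))
```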
The script utilizes the following Python modules:
- TensorFlow - An open-source machine learning framework.
- NumPy - A powerful library for numerical operations in Python.
Ensure these modules are installed in your Python environment before running the script.
```shell
pip install tensorflow numpy
```
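To check from Python whether both packages are already available, you can ask the import system to locate them without running the script. This is a minimal sketch using the standard library's `importlib.util.find_spec`; the helper name `missing_modules` is hypothetical, and it checks the import names `tensorflow` and `numpy`.

```python
import importlib.util


def missing_modules(names):
    """Return the module names that the current environment cannot locate."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(["tensorflow", "numpy"])
if missing:
    print("Missing modules:", ", ".join(missing), "- install them with pip first.")
else:
    print("All required modules are installed.")
```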