[python] 深度学习基础
人工神经网络实现鸢尾花分类(一)
人工神经网络实现鸢尾花分类(二)
人工神经网络实现鸢尾花分类(三)
人工神经网络实现鸢尾花分类(四)
人工神经网络实现鸢尾花分类(五)
本文主要写人工神经网络实现鸢尾花分类代码部分
使用的是Kaggle: Your Machine Learning and Data Science Community在线编译器
本代码属于自己造螺丝类,写的很细,没用神经网络不必要的函数即相关模块。相对较为复杂。
下篇文章会写道利用模块和自带函数实现鸢尾花分类。相对简单很多。
目录
本文主要写人工神经网络实现鸢尾花分类代码部分
鸢尾花数据集(Iris)
主要分为六大块
导入所需模块
准备数据
数据集读入
数据集乱序
生成训练集和测试集(即 x_train / y_train)
配成 (输入特征,标签) 对,每次读入一小撮(batch)
搭建网络
定义神经网路中所有可训练参数
参数优化
嵌套循环迭代,with结构更新参数,显示当前loss
测试效果
acc / loss可视化
输出结果
鸢尾花数据集(Iris)
主要分为六大块
导入包-->准备数据-->搭建网络-->参数优化-->测试效果-->可视化
导入所需模块
import tensorflow as tf
from sklearn import datasets
from matplotlib import pyplot as plt
import numpy as np
准备数据
数据集读入x_data = datasets.load_iris().data
y_data = datasets.load_iris().target
数据集乱序np.random.seed(116)
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)
生成训练集和测试集(即 x_train / y_train)x_train = x_data[:-30]
y_train = y_data[:-30]
x_test = x_data[-30:]
y_test = y_data[-30:]
x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)
配成 (输入特征,标签) 对,每次读入一小撮(batch)train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
搭建网络
定义神经网路中所有可训练参数w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1))
b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1))
lr = 0.1
train_loss_results = []
test_acc = []
epoch = 1000
loss_all = 0
参数优化
嵌套循环迭代,with结构更新参数,显示当前lossfor epoch in range(epoch):
for step, (x_train, y_train) in enumerate(train_db):
with tf.GradientTape() as tape:
y = tf.matmul(x_train, w1) + b1
y = tf.nn.softmax(y)
y_ = tf.one_hot(y_train, depth=3)
loss = tf.reduce_mean(tf.square(y_ - y))
loss_all += loss.numpy()
grads = tape.gradient(loss, [w1, b1])
w1.assign_sub(lr * grads[0])
b1.assign_sub(lr * grads[1])
train_loss_results.append(loss_all / 4)
loss_all = 0
测试效果
计算当前参数前向传播后的准确率,显示当前acc
total_correct, total_number = 0, 0
for x_test, y_test in test_db:
y = tf.matmul(x_test, w1) + b1
y = tf.nn.softmax(y)
pred = tf.argmax(y, axis=1)
pred = tf.cast(pred, dtype=y_test.dtype)
correct = tf.cast(tf.equal(pred, y_test), dtype=tf.int32)
correct = tf.reduce_sum(correct)
total_correct += int(correct)
total_number += x_test.shape[0]
acc = total_correct / total_number
test_acc.append(acc)
acc / loss可视化
plt.figure(figsize=(10, 10))
plt.title('Loss and Acc')
plt.xlabel('Epoch')
plt.ylabel('Loss and Acc')
plt.yticks([0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0,1.1])
plt.plot(train_loss_results, label="$Loss$")
plt.plot(test_acc, label="$Accuracy$")
plt.legend()
plt.show()
输出结果
相关知识
深度学习及其应用
深度学习花卉识别:Python数据集解析
深度学习应用开发
神经网络与深度学习
python毕业设计 深度学习昆虫识别系统
【Python】基础
《Python机器学习开发实战》电子书在线阅读
使用Python实现深度学习模型:智能农业病虫害检测与防治
基于YOLOv8深度学习的智能肺炎诊断系统【python源码+Pyqt5界面+数据集+训练代码】深度学习实战
opencv深度学习昆虫识别系统图像识别 python
网址: [python] 深度学习基础 https://www.huajiangbk.com/newsview840301.html
上一篇: 龚炯:不用看民调!这一关键信息预 |
下一篇: 民调显示84.7%受访者确认“网 |
推荐分享

- 1君子兰什么品种最名贵 十大名 4012
- 2世界上最名贵的10种兰花图片 3364
- 3花圈挽联怎么写? 3286
- 4迷信说家里不能放假花 家里摆 1878
- 5香山红叶什么时候红 1493
- 6花的意思,花的解释,花的拼音 1210
- 7教师节送什么花最合适 1167
- 8勿忘我花图片 1103
- 9橄榄枝的象征意义 1093
- 10洛阳的市花 1039