TensorFlow 简单的二分类神经网络的训练和应用流程

news/2025/2/2 1:26:19 标签: tensorflow, 神经网络, 人工智能

展示了一个简单的二分类神经网络的训练和应用流程。主要步骤包括:

1. 数据准备与预处理

2. 构建模型

3. 编译模型

4. 训练模型

5. 评估模型

6. 模型应用与部署

加载和应用已训练的模型


1. 数据准备与预处理

在本例中,数据准备是通过两个 Numpy 数组来完成的:

  • x:输入特征,形状为 (8, 2),包含 8 个数据点,每个数据点有 2 个特征。
  • y:标签,形状为 (8,),包含对应的 0 或 1 标签,表示每个输入点的类别。
x = np.array([[1, 1], [1, -1], [-1, 1], [-1, -1], [0.7, 0.7], [0.7, -0.7], [-0.7, -0.7], [-0.7, 0.7]])
y = np.array([1, 1, 1, 1, 0, 0, 0, 0])

2. 构建模型

使用 Keras 的 Sequential 模型来构建神经网络。此模型包含两个全连接层(Dense 层):

  • 第一个 Dense 层有 3 个单位,激活函数是 Sigmoid。
  • 第二个 Dense 层有 1 个单位,激活函数是 Sigmoid,输出层的激活函数将模型输出的值映射到 0 到 1 之间,适合二分类任务。
l1 = tf.keras.layers.Dense(units=3, activation='sigmoid')
l2 = tf.keras.layers.Dense(units=1, activation='sigmoid')
model = tf.keras.Sequential([l1, l2])

3. 编译模型

在编译阶段,我们选择了优化器、损失函数和评估指标:

  • 优化器:SGD(随机梯度下降),学习率设置为 0.9。
  • 损失函数:binary_crossentropy,适用于二分类任务。
  • 评估指标:accuracy,表示训练过程中对分类准确率的衡量。
sgd = tf.keras.optimizers.SGD(learning_rate=0.9)
model.compile(optimizer=sgd, loss='binary_crossentropy', metrics=['accuracy'])

4. 训练模型

通过 model.fit() 函数来训练模型。我们传入训练数据 x 和标签 y,并设置训练的 epoch 数量为 2000。

model.fit(x, y, epochs=2000)

5. 评估模型

在此示例中,评估部分通过训练后的 model 来进行,并没有显式写出 evaluate() 函数。评估通常是在训练之后,通过测试集或验证集对模型性能进行评估,具体可以使用 model.evaluate() 来查看损失和准确度。

6. 模型应用与部署

训练完成后,我们保存了训练好的模型。保存后的模型可以被加载和应用于新的数据集。

model.save('my_model.keras')  # 保存模型

7.加载和应用已训练的模型

加载保存的模型,并用其对新数据进行预测。model.predict() 方法返回的是预测的概率值,我们通过设置阈值(如 0.9)将其转换为类别(0 或 1)。

model = tf.keras.models.load_model('my_model.keras')  # 加载模型
nx = np.array([[2, 2], [0.1, 0.1], [1.1, 1.2], [0.3, 0.3]])  # 新的输入数据
predictions = model.predict(nx)  # 获取预测结果
print(predictions)  # 输出概率

# 将概率转化为类别
predicted_classes = (predictions > 0.9).astype(int)
print(predicted_classes)  # 输出最终的类别预测

8.完整代码
test.py 训练模型

import tensorflow as tf
import numpy as np
# 创建int32类型的0维张量,即标量
l1=tf.keras.layers.Dense(units=3,activation='sigmoid')
l2=tf.keras.layers.Dense(units=1,activation='sigmoid')
model=tf.keras.Sequential([l1,l2])
sgd = tf.keras.optimizers.SGD(learning_rate=0.9)
model.compile(optimizer=sgd, loss='binary_crossentropy', metrics=['accuracy'])
x=np.array([[1,1],[1,-1],[-1,1],[-1,-1],[0.7,0.7],[0.7,-0.7],[-0.7,-0.7],[-0.7,0.7]])
y=np.array([1,1,1,1,0,0,0,0])
model.fit(x,y,epochs=2000)
# 保存训练好的模型(Keras 格式)
model.save('my_model.keras')

 test2.py加载模型并进行预测:

import tensorflow as tf
import numpy as np

# 加载训练好的模型
model = tf.keras.models.load_model('my_model.keras')

# 预测数据
nx = np.array([[2, 2], [0.1, 0.1], [1.1, 1.2], [0.3, 0.3]])

# 获取预测结果
predictions = model.predict(nx)

# 输出预测结果
print(predictions)

# 如果需要将概率转化为类别(0或1)
predicted_classes = (predictions > 0.9).astype(int)

# 输出最终的类别预测
print(predicted_classes)

9.视频分享


初识TensorFlow 
https://v.douyin.com/ifG2mmLH/
复制此链接,打开Dou音搜索,直接观看视频!


http://www.niftyadmin.cn/n/5839687.html

相关文章

离散化C++

离散化(Discretization)是一种将连续数据映射到离散值的技术,常用于将大数据范围压缩到较小范围,例如将数值映射到索引。离散化在算法竞赛中常用于处理数值范围较大但数据量较小的问题(如区间问题、统计问题等&#xf…

JMeter中常见的四种参数化实现方式是什么?_file test_params

2 参数化实现 2.1 CSV Data Set Config 在JMeter中提起参数化,我们默认就想到CSV Data Set Config(以下简称CSV),CSV能够读取文件中的数据并生成变量,被JMeter脚本引用,从而实现参数化。下面我们来详细探究…

S4 HANA税码科目确定(OB40)

本文主要介绍在S4 HANA OP中税码科目确定(OB40)相关设置。具体请参照如下内容: 税码科目确定(OB40) 在以上界面维护“Transaction Key”的记账码。 在以上界面进一步维护“Transaction Key”确定科目的规则。 Chart of Account:用于明确该规则适用于什么科目表。 …

Python 列表(使用列表时避免索引错误)

你将学习列表是什么以及如何使用列表元素。列表让你能够在一个地方存储成组的信息,其中可以只包含几个元素,也可以包含数百万个元素。 列表是新手可直接使用的最强大的Python功能之一,它融合了众多重要的编程概念。 使用列表时避免索引错误 …

mysql中in和exists的区别?

大家好,我是锋哥。今天分享关于【mysql中in和exists的区别?】面试题。希望对大家有帮助; mysql中in和exists的区别? 在 MySQL 中,IN 和 EXISTS 都是用于子查询的操作符,但它们在执行原理和适用场景上有所不…

Janus-Pro 论文解读:DeepSeek 如何重塑多模态技术格局

Janus-Pro:多模态领域的璀璨新星——技术解读与深度剖析 一、引言 在人工智能的浩瀚星空中,多模态理解与生成模型犹如耀眼的星座,不断推动着技术边界的拓展。Janus-Pro作为这一领域的新兴力量,以其卓越的性能和创新的架构&#x…

加一(66)

66. 加一 - 力扣&#xff08;LeetCode&#xff09; 解法&#xff1a; class Solution { public:vector<int> plusOne(vector<int>& digits) {bool plus_one true;for (int i digits.size() - 1; i > 0; --i) {if (plus_one) {int tmp digits[i] 1;if …

pytorch逻辑回归实现垃圾邮件检测

人工智能例子汇总&#xff1a;AI常见的算法和例子-CSDN博客 完整代码&#xff1a; import torch import torch.nn as nn import torch.optim as optim from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.model_selection import train_test_split …