KNN模型

news/2024/7/7 21:06:15
使用K-Nearest Neighbors (KNN)算法进行分类。首先加载一个数据集,然后进行预处理,选择最佳的K值,并训练一个KNN模型。
# encoding=utf-8
import numpy as np
datas = np.loadtxt('datingTestSet2.txt')  # 加载数据集,返回一个numpy数组
# 提取特征和标签
x_data = datas[:, 0:3]  # 提取前三列数据作为特征
y_data = datas[:, 3]  # 提取第四列数据作为标签
print('标准化前:', x_data)  # 特征矩阵
print(y_data)  # 标签向量
# 数据maxmin标准化
from sklearn.preprocessing import MinMaxScaler  # 用于数据的标准化
std = MinMaxScaler()  # 创建一个MinMaxScaler对象
x_data = std.fit_transform(x_data)  # 标准化
print('标准化:', x_data)
# 拆分数据集(训练集和测试集)
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, test_size=0.2,
                                                    random_state=123)  # 测试集占总数据的20%,随机种子设为123以保证结果的可重复性
# 建立KNN模型
from sklearn.neighbors import KNeighborsClassifier
# 使用交叉验证法评估模型性能
from sklearn.model_selection import cross_val_score
k_range = range(1, 31)  # 创建一个范围从1到30的序列,用于试验不同的K值。
k_error = []  # 创建一个空列表,用于存储每个K值对应的错误率。
# 找最合适的k,既平均值最高
for k in k_range:
    model_kun = KNeighborsClassifier(n_neighbors=k)
    scores = cross_val_score(model_kun, x_train, y_train, cv=6, scoring="accuracy") 
    # 将数据集分成6个子集
    # 估计方法对象 数据特征 数据标签 几折交叉验证
    meanscores = scores.mean()  # 平均值
    k_error.append(1 - meanscores)  # 将准确率的平均值转换为错误率
    print("k=", k, "meanscores=", meanscores)
# 可视化K值和错误率的关系
import matplotlib.pyplot as plt
plt.plot(k_range, k_error)  # 绘制K值与错误率的图像
plt.show()
# 建立KNN分类器模型,并使用训练集进行训练
model_kun = KNeighborsClassifier(n_neighbors=9)  # n_neighbors=9表示在预测时,KNN分类器将考虑最近的9个邻居,并根据这9个邻居中最常见的类别来预测输入样本的类别
model_kun.fit(x_train, y_train)  # 使用训练集对模型进行训练
scores = model_kun.score(x_test, y_test)  # 使用测试集评估模型性能,返回准确率
print('准确率为:', scores)

 

 

 

 

 使用KNN算法加载鸢尾花数据集

# 加载鸢尾花数据集
from sklearn.datasets import load_iris

iris = load_iris()
print(iris)
x_data = iris.data  # 样本数据
y_data = iris.target  # 标签数据
print("标准化前:", x_data)


# 数据maxmin标准化
from sklearn.preprocessing import MinMaxScaler

mms = MinMaxScaler()
x_data = mms.fit_transform(x_data)
print(x_data)

# 拆分数据集(训练集和测试集)
from sklearn.model_selection import train_test_split

x_train, x_test, y_train, y_test = train_test_split(x_data, y_data,   test_size=0.2,random_state=123)

# 建立knn模型
from sklearn.neighbors import KNeighborsClassifier

from sklearn.model_selection import cross_val_score

k_range=range(1,31)
k_error=[] #错误率

# 找最合适的k,既平均值最高
for k in k_range:
    model_kun=KNeighborsClassifier(n_neighbors=k)
    scores=cross_val_score(model_kun,x_train,y_train,cv=6,scoring="accuracy")
    # 估计方法对象 数据特征 数据标签 几折交叉验证
    meanscores=scores.mean()    # 平均值
    k_error.append(1-meanscores)    # 错误率
    print("k=",k,"meanscores=",meanscores)

# 将k的值和错误率可视化出来,比较好找
import matplotlib.pyplot as plt
plt.plot(k_range,k_error)
plt.show()

model_knn = KNeighborsClassifier(n_neighbors=10)

model_knn.fit(x_train, y_train)
scores = model_knn.score(x_test, y_test)  # 准确率
print(scores)

 


http://lihuaxi.xjx100.cn/news/1738228.html

相关文章

Hafnium之传递启动数据给SP和SP启动顺序

安全之安全(security)博客目录导读 目录 一、将启动数据传递给SP 二、SP启动顺序 一、将启动数据传递给SP

threejs(10)-WEBGL与GPU渲染原理(难点)后期再消化亦可

一、渲染管线 WebGL 是什么 WebGL (Web图形库)是一个JavaScript API,可在任何兼容的Web浏览器中渲染高性能的交互式3D和2D图形,而无需使用插件。WebGL通过引入一个与OpenGL ES 2.0非常一致的API来做到这一点,该API可以在HTML5 元素中使用。这种一致性使API可以利用用户设备提…

Java,面向对象,多态性

多态性是面向对象的第三大重要特征,建立在继承性之上。 多态性一词怎么理解呢?就是一个事物的多种形态的性质。在面向对象中,主要体现为一个父类的属性方法可以继承给多个子类。子类就理解为父类的多种形态。以动物为例,猫和狗都有…

【技术综述】深度学习模型结构复杂、参数众多,如何更直观地深入理解你的模型?...

CNN、RNN等深度学习模型使用的门槛虽然低,但模型参数多,网络结构复杂。输出如何关联模型的参数,在数学上没有很直观的解释,导致模型网络结构的设计以及训练过程中超参数的调试,都非常依赖于经验。结果不好,…

Combination Sum IV【中等难度】

Combination Sum IV【中等难度】 以下是一道力扣中等难度的题目:Combination Sum IV 题目描述: 给定一个由正整数组成的数组 nums 和一个正整数 target,请找出总和为 target 的不同组合的数量。组合中的数字可以在组合中出现任意次。 示例: 输入: nums = [1, 2, 3], tar…

3 — NLP 中的标记化:分解文本数据的艺术

一、说明 这是一个系列文章的第三篇文章, 文章前半部分分别是: 1 — NLP 的文本预处理技术2 — NLP中的词干提取和词形还原:文本预处理技术 在本文中,我们将介绍标记化主题。在开始之前,我建议您阅读我之前介绍…

微信小程序开发-微信支付退款功能【附有完整代码】

之前有写过详细的微信支付功能:微信支付 我们使用weixin-java-pay的jar包等,配置上的流程同微信支付,可以看上面的文章。 退款使用的WxPayService类的refundV3方法。使用该方法需要在微信支付配置的基础上加上:apiclient_key.pem…

[USACO23OPEN] Field Day S题解

远古的回忆。 把变换一个字符视为边权为 1 1 1 的边&#xff0c;即求最长路。 最长路不好搞&#xff0c;考虑转补集最短路&#xff08;容易感性理解&#xff09;&#xff0c;BFS 即可。 #include<bits/stdc.h> #define int long long using namespace std;const int …