K 值选对,准确率翻倍:KNN 算法调参的黄金法则

目录

 

一、背景介绍

二、KNN 算法原理

2.1 核心思想

2.2 距离度量方法

2.3 算法流程

2.4算法结构:

三、KNN 算法代码实现

3.1 基于 Scikit-learn 的简单实现

3.2 手动实现 KNN(自定义代码)

四、K 值选择与可视化分析

4.1 K 值对分类结果的影响

4.2 交叉验证选择最优 K 值

五、KNN 算法的优缺点与优化

5.1 优点

5.2 缺点

5.3 优化方法

六、KNN 算法的应用场景

七、KNN 与其他算法的对比

八、小结


 

一、背景介绍

K 近邻算法(K-Nearest Neighbors, KNN)是机器学习中最简单、最直观的算法之一,其核心思想源于人类对相似事物的判断逻辑 ——“近朱者赤,近墨者黑”。该算法无需复杂的训练过程,直接通过计算样本间的距离来进行分类或回归,广泛应用于图像识别、文本分类、推荐系统等领域。

二、KNN 算法原理

2.1 核心思想

KNN 的核心思想是:对于一个待预测样本,找到训练数据中与其最相似的 K 个样本(近邻),根据这 K 个样本的类别(分类问题)或数值(回归问题)进行投票或平均,从而确定待预测样本的类别或数值。

关键点

相似性度量:通过距离函数衡量样本间的相似性。

K 值选择:近邻数量 K 对结果影响显著。

投票机制:分类问题通常采用多数投票,回归问题采用均值或加权平均。

2.2 距离度量方法

常见的距离度量方法包括:

欧氏距离:适用于连续变量,计算两点间的直线距离。

曼哈顿距离:适用于城市网格路径等场景,计算两点间的折线距离。

余弦相似度:适用于文本、图像等高维数据,衡量向量间的方向相似性。

2.3 算法流程

KNN 算法的典型流程如下:

1·数据预处理:对数据进行清洗、归一化,避免特征量纲影响距离计算。

2·计算距离:计算待预测样本与所有训练样本的距离。

3·选择近邻:按距离升序排列,选取前 K 个最近邻样本。

4·分类 / 回归决策

分类:统计 K 个近邻的类别,选择出现次数最多的类别。

回归:计算 K 个近邻数值的平均值或加权平均值。

2.4算法结构:

三、KNN 算法代码实现

3.1 基于 Scikit-learn 的简单实现

以鸢尾花数据集(Iris Dataset)为例,演示 KNN 分类的完整流程。

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score# 加载鸢尾花数据集
iris = datasets.load_iris()
X = iris.data[:, :2]  # 仅取前两个特征,便于可视化
y = iris.target
feature_names = iris.feature_names[:2]# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 数据标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)# 创建KNN分类器(K=5)
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)# 预测测试集
y_pred = knn.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy with K=5: {accuracy:.2f}")  # 输出:Accuracy with K=5: 0.98

3.2 手动实现 KNN(自定义代码)

为深入理解算法原理,我们手动实现 KNN 分类器:

class CustomKNN:def __init__(self, n_neighbors=3):self.n_neighbors = n_neighborsdef fit(self, X_train, y_train):self.X_train = X_trainself.y_train = y_traindef predict(self, X_test):predictions = []for x in X_test:# 计算距离distances = [np.sqrt(np.sum((x - x_train)**2)) for x_train in self.X_train]# 获取最近的K个样本索引k_indices = np.argsort(distances)[:self.n_neighbors]# 获取对应的类别k_nearest_labels = self.y_train[k_indices]# 多数投票most_common = np.bincount(k_nearest_labels).argmax()predictions.append(most_common)return np.array(predictions)# 使用自定义KNN
custom_knn = CustomKNN(n_neighbors=3)
custom_knn.fit(X_train, y_train)
y_pred_custom = custom_knn.predict(X_test)
print(f"Custom KNN Accuracy: {accuracy_score(y_test, y_pred_custom):.2f}")  # 输出:0.96

四、K 值选择与可视化分析

4.1 K 值对分类结果的影响

K 值是 KNN 算法的核心超参数,其大小直接影响分类结果:

  • K 值过小:模型复杂度高,易受噪声影响,导致过拟合。
  • K 值过大:模型趋于平滑,可能忽略局部特征,导致欠拟合。

示例:在鸢尾花数据集上,不同 K 值的分类边界差异如下:

def plot_decision_boundary(clf, X, y, title, k=None):plt.figure(figsize=(8, 6))x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),np.arange(y_min, y_max, 0.02))Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])Z = Z.reshape(xx.shape)plt.contourf(xx, yy, Z, alpha=0.8)# 绘制散点图for i, color in zip([0, 1, 2], ['r', 'g', 'b']):idx = np.where(y == i)plt.scatter(X[idx, 0], X[idx, 1], c=color, label=iris.target_names[i], edgecolor='k')plt.xlabel(feature_names[0])plt.ylabel(feature_names[1])plt.title(f"KNN Decision Boundary (K={k})")plt.legend()plt.show()# K=1(过拟合)
knn1 = KNeighborsClassifier(n_neighbors=1)
knn1.fit(X_train, y_train)
plot_decision_boundary(knn1, X_test, y_test, "K=1", k=1)# K=15(欠拟合)
knn15 = KNeighborsClassifier(n_neighbors=15)
knn15.fit(X_train, y_train)
plot_decision_boundary(knn15, X_test, y_test, "K=15", k=15)

4.2 交叉验证选择最优 K 值

通过交叉验证可以有效选择最优 K 值:

from sklearn.model_selection import cross_val_score# 候选K值
k_values = range(1, 31)
cv_scores = []for k in k_values:knn = KNeighborsClassifier(n_neighbors=k)scores = cross_val_score(knn, X_train, y_train, cv=5, scoring='accuracy')cv_scores.append(scores.mean())# 绘制K值与准确率曲线
plt.plot(k_values, cv_scores, marker='o', linestyle='--', color='b')
plt.xlabel('K Value')
plt.ylabel('Cross-Validation Accuracy')
plt.title('K Value Selection via Cross-Validation')
plt.show()

五、KNN 算法的优缺点与优化

5.1 优点

简单易懂:原理直观,无需复杂数学推导。

无需训练:直接使用训练数据进行预测。

泛化能力强:对非线性数据分布有较好的适应性。

5.2 缺点

计算复杂度高:预测时需计算与所有训练样本的距离。

存储成本高:需存储全部训练数据。

对噪声敏感:K 值过小时,异常值可能显著影响结果。

5.3 优化方法

数据预处理:归一化、特征选择。

近似最近邻搜索:KD 树、球树等加速算法。

加权投票:根据距离赋予不同权重。

六、KNN 算法的应用场景

  • 图像识别与分类:常用于手写数字识别、人脸识别等任务。
  •  推荐系统:基于用户或物品的相似度进行推荐。
  •  医疗诊断:根据患者的临床指标预测疾病类别。
  •  异常检测:通过判断样本与近邻的距离识别异常点。

七、KNN 与其他算法的对比

算法核心思想优点缺点适用场景
KNN基于相似性投票 / 平均简单直观、无需训练计算慢、存储成本高、高维性能差小规模数据、实时预测
逻辑回归基于概率的线性分类训练快、可解释性强仅适用于线性可分数据、需调参二分类、概率预测
决策树基于特征划分的树结构分类可解释性强、能处理非线性数据易过拟合、对噪声敏感分类规则提取、快速预测

八、小结

KNN 算法以其简单性和直观性成为机器学习入门的经典算法,适用于小规模、低维数据的快速分类 / 回归任务。尽管存在计算效率和高维性能的局限,但其思想为许多复杂算法提供了基础。通过数据预处理、近似搜索和加权机制,KNN 的实用性可进一步提升;未来,随着硬件计算能力的提升和近似搜索算法的发展,KNN 在大规模数据中的应用可能迎来新突破。结合深度学习的特征提取能力,可构建更强大的混合模型。

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/web/82440.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Azure DevOps Server 2022.2 补丁(Patch 5)

微软Azure DevOps Server的产品组在4月8日发布了2022.2 的第5个补丁。下载路径为:https://aka.ms/devops2022.2patch5 这个补丁的主要功能是修改了代理(Agent)二进制安装文件的下载路径;之前,微软使用这个CND(域名为vstsagentpackage.azuree…

PHP7+MySQL5.6 查立得轻量级公交查询系统

# PHP7MySQL5.6 查立得轻量级公交查询系统 ## 系统简介 本系统是一个基于PHP7和MySQL5.6的轻量级公交查询系统(40KB级),支持线路查询、站点查询和换乘查询功能。系统采用原生PHPMySQL开发,无需第三方框架,适合手机端访问。 首发版本&#x…

Vue-Cropper:全面掌握图片裁剪组件

Vue-Cropper 完全学习指南:Vue图片裁剪组件 🎯 什么是 Vue-Cropper? Vue-Cropper 是一个简单易用的Vue图片裁剪组件,支持Vue2和Vue3。它提供了丰富的配置选项和回调方法,可以满足各种图片裁剪需求。 🌟 …

[Go] Option选项设计模式 — — 编程方式基础入门

[Go] Option选项设计模式 — — 编程方式基础入门 全部代码地址,欢迎⭐️ Github:https://github.com/ziyifast/ziyifast-code_instruction/tree/main/go-demo/go-option 1 介绍 在 Go 开发中,我们经常遇到需要处理多参数配置的场景。传统方…

【Unity开发】控制手机移动端的震动

🐾 个人主页 🐾 阿松爱睡觉,横竖醒不来 🏅你可以不屠龙,但不能不磨剑🗡 目录 一、前言二、Unity的Handheld.Vibrate()三、调用Android原生代码四、NiceVibrations插件五、DeviceVibration插件六、控制游戏手…

Linux 软件安装方式全解(适用于 CentOS/RHEL 系统)

🐧 Linux 软件安装方式全解(适用于 CentOS/RHEL 系统) 在 Linux 系统中,软件安装方式丰富多样,常见于以下几种方式: 安装方式命令/工具说明软件包管理器(推荐)yum, dnf, apt, zypp…

前端面试题-HTML篇

1. 请谈谈你对 Web 标准以及 W3C 的理解和认识。 我对 Web 标准 的理解是,它就像是互联网世界的“交通规则”,由 W3C(World Wide Web Consortium,万维网联盟) 这样一个国际性组织制定。这些规则规范了我们在编写 HTML、CSS 和 JavaScript 时应该遵循的语法和行为,比如要…

ERROR: column cl.udt_name does not exist LINE 1 navicat打开金仓表报错

描述: ERROR: column cl.udt_name does not exist LINE 1: …a.columns cl LEFT JOlN pg type ty ON ty.typname cl.udt nam. navicat连上金仓数据库之后,想打开一张表看看,每张表都报这个错,打不开 解决方案: 网上…

2025年- H61-Lc169--74.搜索二维矩阵(二分查找)--Java版

1.题目描述 2.思路 方法一: 定义其实坐标,右上角的元素(0,n-1)。进入while循环(注意边界条件,行数小于m,列数要>0)从右上角开始开始向左遍历(比当…

Jupyter MCP服务器部署实战:AI模型与Python环境无缝集成教程

Jupyter MCP 服务器是基于模型上下文协议(Model Context Protocol, MCP)的 Jupyter 环境扩展组件,它能够实现大型语言模型与实时编码会话的无缝集成。该服务器通过标准化的协议接口,使 AI 模型能够安全地访问和操作 Jupyter 的核心…

MySQL下载安装配置环境变量

MySQL下载安装配置环境变量 文章目录 MySQL下载安装配置环境变量一、安装MySQL1.1 下载1.2 安装 二、查看MySQL服务是否启动三、配置环境变量四、验证 一、安装MySQL 1.1 下载 官网社区版(免费版):https://dev.mysql.com/downloads/mysql/ …

WSL 安装 Debian 12 后,Linux 如何安装 curl , quickjs ?

在 WSL 的 Debian 12 系统中安装 curl 非常简单,你可以直接使用 APT 包管理器从官方仓库安装。以下是详细步骤: 1. 更新软件包索引 首先确保系统的包索引是最新的: sudo apt update2. 安装 curl 执行以下命令安装 curl: sudo…

Linux入门(十四)rpmyum

RPM 是RedHat PackManager的缩写 rpm是用于互联网下载包的打包及安装工具 rpm查询 查询已安装的rpm列表 rpm -qa查看系统是否安装了psmisc rpm -qa | grep psmisc rpm -q psmisc查询软件包信息 rpm -qi psmisc查询软件包中的文件 rpm -ql psmisc根据文件全路径 查询文件所…

[git]忽略.gitignore文件

git rm --cached .gitignore 是一个 Git 命令,主要用于 从版本控制中移除已追踪的 .gitignore 文件,但保留该文件在本地工作目录中。以下是详细解析: 一、命令拆解与核心作用 语法解析 git rm:Git 的删除命令,用于从版本库(Repository)中移除文件。--cached:关键参数…

Hive SQL 中 BY 系列关键字全解析:从排序、分发到分组的核心用法

一、排序与分发相关 BY 关键字 1. ORDER BY:全局统一排序 作用:对查询结果进行全局排序,确保最终结果集完全有序(仅允许单个 Reducer 处理数据)。 语法: SELECT * FROM table_name ORDER BY column1 [A…

网络爬虫 - App爬虫及代理的使用(十一)

App爬虫及代理的使用 一、App抓包1. App爬虫原理2. reqable的安装与配置1. reqable安装教程2. reqable的配置3. 模拟器的安装与配置1. 夜神模拟器的安装2. 夜神模拟器的配置4. 内联调试及注意事项1. 软件启动顺序2. 开启抓包功能3. reqable面板功能4. 夜神模拟器设置项5. 注意事…

【25.06】FISCOBCOS使用caliper自定义测试 通过webase 单机四节点 helloworld等进行测试

前置条件 安装一个Ubuntu20+的镜像 基础环境安装 Git cURL vim jq sudo apt install -y git curl vim jq Docker和Docker-compose 这个命令会自动安装docker sudo apt install docker-compose sudo chmod +x /usr/bin/docker-compose docker versiondocker-compose vers…

【基础】Unity中Camera组件知识点

一、投影模式 (Projection) 1. 透视模式 (Perspective) 原理:模拟人眼,近大远小(锥形体视锥) 核心参数: Field of View (FOV):垂直视场角 典型值:第一人称 60-90,驾驶舱 30-45 特…

PCA(K-L变换)人脸识别(python实现)

数据集分析 ORL数据集, 总共40个人,每个人拍摄10张人脸照片 照片格式为灰度图像,尺寸112 * 92 特点: 图像质量高,无需灰度运算、去噪等预处理 人脸已经位于图像正中央,但部分图像角度倾斜(可…

【Git】View Submitted Updates——diff、show、log

在 Git 中查看更新的内容(即工作区、暂存区或提交之间的差异)是日常开发中的常见操作。以下是常用的命令和场景说明: 文章目录 1、查看工作区与暂存区的差异2、查看提交历史中的差异3、查看工作区与最新提交的差异4、查看两个提交之间的差异5…