网站首页 > 文章精选 正文
这是我的第274篇原创文章。
一、引言
对于表格数据,一套完整的机器学习建模流程如下:
针对不同的数据集,有些步骤不适用,其中橘红色框为必要步骤,欢迎大家关注翻看我之前的一些相关文章。前面我介绍了机器学习模型的二分类任务,接下来做一个机器学习模型的回归任务系列,由于本系列案例数据质量较高,有些步骤跳过了,跳过的步骤将单独出文章总结!在Python中,可以使用Scikit-learn库来构建GBDT回归模型进行预测,本文以预测房价为例,对这个过程做一个简要解读。
二、实现过程
2.1 读取数据
filename = 'data.csv'
dataset = pd.read_csv(filename, names=names, delim_whitespace=True)
df = pd.DataFrame(dataset)
df:
2.2 数据集划分
features = names[:-1]
target = ['MEDV']
# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(df[features], df[target], test_size=0.2, random_state=0)
2.3 数据归一化
# 无需该步骤
2.4 建模预测
model = GradientBoostingRegressor(random_state=0).fit(X_train, y_train)
y_train_pred = model.predict(X_train)
y_test_pred = model.predict(X_test)
2.5 结果可视化
# 训练集预测值与真实值的对比
plt.plot(list(range(0,len(X_train))),y_train,marker='o')
plt.plot(list(range(0,len(X_train))),y_train_pred,marker='*')
plt.legend(['真实值','预测值'])
plt.xlabel('序列')
plt.ylabel('房价')
plt.title('训练集预测值与真实值的对比')
plt.show()
结果:
# 验证集预测值与真实值的对比
plt.plot(list(range(0,len(X_test))),y_test,marker='o')
plt.plot(list(range(0,len(X_test))),y_test_pred,marker='*')
plt.legend(['真实值','预测值'])
plt.xlabel('序列')
plt.ylabel('房价')
plt.title('验证集预测值与真实值的对比')
plt.show()
结果:
2.6 评价指标
# 评价指标
trainScore1 = math.sqrt(mean_squared_error(y_train, y_train_pred))
print('Train Score: %.2f RMSE' % (trainScore1))
testScore1 = math.sqrt(mean_squared_error(y_test, y_test_pred))
print('Test Score: %.2f RMSE' % (testScore1))
trainScore2 = mean_absolute_error(y_train, y_train_pred)
print('Train Score: %.2f MAE' % (trainScore2))
testScore2 = mean_absolute_error(y_test, y_test_pred)
print('Test Score: %.2f MAE' % (testScore2))
trainScore3 = r2_score(y_train, y_train_pred)
print('Train Score: %.2f R2' % (trainScore3))
testScore3 = r2_score(y_test, y_test_pred)
print('Test Score: %.2f R2' % (testScore3))
trainScore4 = mean_absolute_percentage_error(y_train, y_train_pred)
print('Train Score: %.2f MAPE' % (trainScore4))
testScore4 = mean_absolute_percentage_error(y_test, y_test_pred)
print('Test Score: %.2f MAPE' % (testScore4))
结果打印:
作者简介: 读研期间发表6篇SCI数据算法相关论文,目前在某研究院从事数据算法相关研究工作,结合自身科研实践经历持续分享关于Python、数据分析、特征工程、机器学习、深度学习、人工智能系列基础知识与案例。关注gzh:数据杂坛,获取数据和源码学习更多内容。
原文链接:
猜你喜欢
- 2025-07-02 深度学习中的损失函数(损失函数原理)
- 2025-07-02 【Python机器学习系列】一文教你建立SVR模型预测房价(源码)
- 2025-07-02 【Python机器学习系列】一文教你建立LightGBM模型预测房价
- 2025-07-02 如何利用数字化手段有效监测换流阀冷却系统?这里有答案
- 2025-07-02 【Python机器学习系列】一文教你建立随机森林模型预测房价
- 2025-07-02 AI回归模型评估指标:MSE、RMSE、MAE、R2
- 2025-07-02 一文彻底搞懂自动机器学习AutoML:TPOT
- 2025-07-02 Matlab和Python环境下的深度学习小项目(第二篇)
- 2025-07-02 使用Flask应用框架在Centos7.8系统上部署机器学习模型
- 2025-07-02 【机器学习】数据挖掘神器LightGBM详解(附代码)
- 最近发表
- 标签列表
-
- newcoder (56)
- 字符串的长度是指 (45)
- drawcontours()参数说明 (60)
- unsignedshortint (59)
- postman并发请求 (47)
- python列表删除 (50)
- 左程云什么水平 (56)
- 计算机网络的拓扑结构是指() (45)
- 编程题 (64)
- postgresql默认端口 (66)
- 数据库的概念模型独立于 (48)
- 产生系统死锁的原因可能是由于 (51)
- 数据库中只存放视图的 (62)
- 在vi中退出不保存的命令是 (53)
- 哪个命令可以将普通用户转换成超级用户 (49)
- noscript标签的作用 (48)
- 联合利华网申 (49)
- swagger和postman (46)
- 结构化程序设计主要强调 (53)
- 172.1 (57)
- apipostwebsocket (47)
- 唯品会后台 (61)
- 简历助手 (56)
- offshow (61)
- mysql数据库面试题 (57)