如何通过akshare库,获取股票数据,并生成TabPFN这个模型 可以识别、处理的格式(并进行了训练、推理)

计划让AI帮助编程使用TabPFN模型进行股价推理

原计划提问的prompt

如何通过akshare库,获取股票数据,并生成TabPFN这个模型 可以识别、处理的格式

本意是想让AI分步执行,先处理股票数据,然后再进行模型训练,结果豆包超长发挥,直接把后面的模型预测部分也写了,导致其它两个模型都需要再附加上模型预测的prompt,

最终除豆包外的prompt

通过akshare库,获取股票数据,并生成TabPFN这个模型 可以识别、处理的格式,写一个完整的预处理示例,并构建一个预测未来 3 天股价涨跌的分类任务

豆包答题

豆包的总体表现中规中矩,只用一句“数据处理”的prompt,它就完成了整个任务,这是没有预料到的,也就是豆包会抢答了!

但是生成的代码无法一次跑通,大约修改了2-3次,而且有一个地方AI是改不过来的,必须手工改。当然手工改这个地方其它AI也没pass。

豆包的代码

import akshare as ak
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from tabpfn import TabPFNClassifier# 获取股票数据
stock_data = ak.stock_zh_a_hist(symbol="600519", period="daily", start_date="20200101", end_date="20230101",adjust="qfq")# 数据预处理函数
def prepare_data_for_tabpfn(stock_data, days_ahead=3):df = stock_data.copy()# 动态处理列名original_columns = df.columns.tolist()print(f"原始列名: {original_columns}")if len(original_columns) == 12:df.columns = ['日期', '开盘', '收盘', '最高', '最低', '成交量', '成交额', '振幅', '涨跌幅', '涨跌额', '换手率', '未知列']else:df.columns = [f"col_{i}" for i in range(len(original_columns))]print(f"警告: 列数量不匹配,使用默认列名: {df.columns.tolist()}")# 特征工程df['收盘价_滞后1'] = df['收盘'].shift(1)df['收盘价_滞后2'] = df['收盘'].shift(2)df['收盘价_滞后3'] = df['收盘'].shift(3)df['涨跌幅_1'] = df['收盘'].pct_change(1)df['涨跌幅_2'] = df['收盘'].pct_change(2)df['涨跌幅_3'] = df['收盘'].pct_change(3)df['MA5'] = df['收盘'].rolling(window=5).mean()df['MA10'] = df['收盘'].rolling(window=10).mean()df['MA20'] = df['收盘'].rolling(window=20).mean()df['波动率_5'] = df['涨跌幅_1'].rolling(window=5).std()df['波动率_10'] = df['涨跌幅_1'].rolling(window=10).std()df['成交量_变化'] = df['成交量'].pct_change(1)df['成交量_MA5'] = df['成交量'].rolling(window=5).mean()df['未来涨跌'] = np.where(df['收盘'].shift(-days_ahead) > df['收盘'], 1, 0)df = df.dropna()feature_columns = ['收盘价_滞后1', '收盘价_滞后2', '收盘价_滞后3','涨跌幅_1', '涨跌幅_2', '涨跌幅_3','MA5', 'MA10', 'MA20','波动率_5', '波动率_10','成交量_变化', '成交量_MA5']valid_features = [col for col in feature_columns if col in df.columns]if len(valid_features) < len(feature_columns):missing_features = set(feature_columns) - set(valid_features)print(f"警告: 缺少以下特征: {missing_features}")feature_columns = valid_featuresX = df[feature_columns].valuesy = df['未来涨跌'].valuesfeature_names = feature_columnsX_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, shuffle=False)scaler = StandardScaler()X_train = scaler.fit_transform(X_train)X_test = scaler.transform(X_test)return X_train, X_test, y_train, y_test, feature_names# 准备数据
X_train, X_test, y_train, y_test, feature_names = prepare_data_for_tabpfn(stock_data)# 初始化并训练TabPFN模型
model = TabPFNClassifier(device='cpu')
model.fit(X_train, y_train)# 预测(更新的API调用)
y_pred = model.predict(X_test)
y_prob = model.predict_proba(X_test)[:, 1]  # 获取正类(上涨)的概率# 评估模型
accuracy = np.mean(y_pred == y_test)
print(f"模型准确率: {accuracy:.4f}")# 查看前10个预测结果
print("\n前10个预测结果:")
for i in range(10):print(f"样本 {i+1}: 实际={y_test[i]}, 预测={y_pred[i]}, 概率={y_prob[i]:.4f}")

豆包的输出

原始列名: ['日期', '股票代码', '开盘', '收盘', '最高', '最低', '成交量', '成交额', '振幅', '涨跌幅', '涨跌额', '换手率']模型准确率: 0.5915前10个预测结果:
样本 1: 实际=1, 预测=1, 概率=0.6773
样本 2: 实际=0, 预测=0, 概率=0.4536
样本 3: 实际=1, 预测=1, 概率=0.5055
样本 4: 实际=1, 预测=0, 概率=0.4203
样本 5: 实际=1, 预测=1, 概率=0.6198
样本 6: 实际=1, 预测=1, 概率=0.6163
样本 7: 实际=1, 预测=1, 概率=0.5240
样本 8: 实际=1, 预测=1, 概率=0.5556
样本 9: 实际=0, 预测=1, 概率=0.5418
样本 10: 实际=1, 预测=0, 概率=0.4687

跃问表现

只用前面一句prompt不行,它没像豆包那样帮着把推理、预测也做了,所以要加上一句

prompt

通过akshare库,获取股票数据,并生成TabPFN这个模型 可以识别、处理的格式,写一个完整的预处理示例,并构建一个预测未来 3 天股价涨跌的分类任务

跃问代码

# 安装必要的库(如果尚未安装)
# pip install akshare pandas sklearn# 导入必要的库
import akshare as ak
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score# 设置股票代码和日期范围
stock_code = "000001" # 示例股票代码
start_date = "2023-01-01"
end_date = "2023-12-31"# 获取历史日线数据
stock_data = ak.stock_zh_a_hist(symbol=stock_code, start_date=start_date, end_date=end_date, adjust="qfq")# 将数据转换为DataFrame
df = pd.DataFrame(stock_data)# 查看数据
print(df.head())# 数据预处理
# 将日期字段转换为日期类型
df["日期"] = pd.to_datetime(df["日期"])# 按日期排序
df = df.sort_values(by="日期")# 计算未来3天的涨跌情况
future_3_day_change = df["收盘价"].shift(-3).sub(df["收盘价"])
df["未来3天涨跌"] = future_3_day_change.apply(lambda x: "涨" if x > 0 else "跌")# 删除NaN值
df.dropna(subset=["未来3天涨跌"], axis=0, how="any")# 查看排序后的数据
print(df.head())# 保存数据为CSV文件
df.to_csv(f"{stock_code}_daily_data.csv", index=False, encoding="utf-8-sig")
print(f"数据已保存到 {stock_code}_daily_data.csv")# 构建分类任务
# 特征和目标变量
X = df[["开盘价", "最高价", "最低价", "成交量"]]
y = df["未来3天涨跌"]# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 构建随机森林分类器
clf = RandomForestClassifier(n_estimators=100, random_state=42)# 训练模型
clf.fit(X_train, y_train)# 预测测试集
y_pred = clf.predict(X_test)# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率为:{accuracy:%.2f}")

这个有报错,懒得改了

KeyError                                  Traceback (most recent call last)
/tmp/ipykernel_35/3161264376.py in <cell line: 0>()
      1 # 将日期字段转换为日期类型
----> 2 df["日期"] = pd.to_datetime(df["日期"])
      3 
      4 # 按日期排序
      5 df = df.sort_values(by="日期") 

修改一次后还有报错:

AttributeError: 'DataFrame' object has no attribute 'columnss'

放弃了。

 

文心4.5表现

 生成数据集


import akshare as ak
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split# 1. 获取股票数据
def get_stock_data(stock_code="000001", start_date="20200101", end_date="20250604"):df = ak.stock_zh_a_hist(symbol=stock_code, period="daily", start_date=start_date, end_date=end_date)df['日期'] = pd.to_datetime(df['日期'])df = df.sort_values('日期').reset_index(drop=True)return df# 2. 特征工程
def create_features(df, window=5):# 基础特征df['return'] = df['收盘'].pct_change()df['volatility'] = df['return'].rolling(window).std()df['ma5'] = df['收盘'].rolling(5).mean()df['ma10'] = df['收盘'].rolling(10).mean()# 技术指标df['rsi'] = compute_rsi(df['收盘'], 14)df['macd'], df['signal'] = compute_macd(df['收盘'])# 未来3天涨跌标签 (1:涨, 0:跌)df['target'] = (df['收盘'].shift(-3) > df['收盘']).astype(int)# 删除无效行df = df.dropna().reset_index(drop=True)return df# 3. 技术指标计算函数
def compute_rsi(series, period=14):delta = series.diff()gain = delta.where(delta > 0, 0)loss = -delta.where(delta < 0, 0)avg_gain = gain.rolling(period).mean()avg_loss = loss.rolling(period).mean()rs = avg_gain / avg_lossreturn 100 - (100 / (1 + rs))def compute_macd(series, fast=12, slow=26, signal=9):ema_fast = series.ewm(span=fast).mean()ema_slow = series.ewm(span=slow).mean()macd = ema_fast - ema_slowsignal_line = macd.ewm(span=signal).mean()return macd, signal_line# 4. 数据预处理
def preprocess_data(df):features = df.drop(['日期', '股票代码', '开盘', '最高', '最低', '收盘', '成交量', '成交额', '振幅', '涨跌幅', '涨跌额', '换手率'], axis=1, errors='ignore')target = df['target']# 标准化from sklearn.preprocessing import StandardScalerscaler = StandardScaler()features_scaled = scaler.fit_transform(features)return features_scaled, target# 5. 主流程
if __name__ == "__main__":# 获取数据stock_df = get_stock_data("600519")  # 茅台股票# 特征工程feature_df = create_features(stock_df)# 预处理X, y = preprocess_data(feature_df)# 划分训练测试集X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=False, random_state=42)# 保存为TabPFN可读格式np.savez('stock_data.npz', X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test)print("数据预处理完成,已保存为stock_data.npz")

训练预测

import numpy as np
from tabpfn import TabPFNClassifier# 加载预处理数据
data = np.load('stock_data.npz')
X_train, X_test = data['X_train'], data['X_test']
y_train, y_test = data['y_train'], data['y_test']# 初始化TabPFN
# classifier = TabPFNClassifier(device='cpu', n_estimators=8)
classifier = TabPFNClassifier(device='cuda', n_estimators=8)
N_ensemble_configurations=32
# 训练模型
classifier.fit(X_train, y_train)# 预测
y_pred = classifier.predict(X_test)
y_prob = classifier.predict_proba(X_test)[:, 1]# 评估
from sklearn.metrics import accuracy_score, roc_auc_score
print(f"测试集准确率: {accuracy_score(y_test, y_pred):.4f}")
print(f"测试集AUC: {roc_auc_score(y_test, y_prob):.4f}")# 预测未来3天涨跌
latest_data = X_test[-1].reshape(1, -1)  # 取最新数据
future_pred = classifier.predict(latest_data)
print(f"未来3天预测结果: {'上涨' if future_pred[0] == 1 else '下跌'}")

 

最终用的是cuda 

输出结果:

测试集准确率: 1.0000
测试集AUC: 1.0000
未来3天预测结果: 下跌

这个,有点问题吧,这训练的也太狠了....

总结

当前的AI都不能小瞧了,跟2年前比都有很大提升。

豆包还是比较有趣,它的理解能力不错,主动性高,但是水平比较平庸。

文心4.5是三者中水平最高的。

跃问,主要是在pandas处理数据那块就出问题了,再修改了一次之后,我就放弃了。

调试

发现几个ai都碰到了这个错误

TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_35/4228700019.py in <cell line: 0>()
      8 
      9 # 初始化TabPFN
---> 10 classifier = TabPFNClassifier(device='cpu', N_ensemble_configurations=32)
     11 
     12 # 训练模型

TypeError: TabPFNClassifier.__init__() got an unexpected keyword argument 'N_ensemble_configurations'

通过help(TabPFNClassifier) ,看到应该是:

n_estimators

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/news/908106.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[蓝桥杯]最大化股票交易的利润

最大化股票交易的利润 题目描述 实现一个算法寻找最大化股票交易利润的策略。介绍如下&#xff1a; 股票价格每天都在变化&#xff0c;以数组的索引表示交易日&#xff0c;以数组的元素表示每天的股票价格。可以通过买入和卖出获得利润。一天只能进行一次买入或卖出操作&…

URL 结构说明+路由(接口)的认识

一、URL 结构说明 以这个为例&#xff1a;http://127.0.0.1:5000/zhouleifeng 1.组成部分: http://&#xff1a;协议 127.0.0.1&#xff1a;主机&#xff08;本地地址&#xff09; :5000&#xff1a;端口号&#xff08;Flask 默认 5000&#xff09; /zhouleifeng&#xff1a…

微服务商城-用户微服务

数据表 用户表 CREATE DATABASE user; USE user;CREATE TABLE user (id bigint(20) UNSIGNED NOT NULL AUTO_INCREMENT COMMENT 用户ID,username varchar(50) NOT NULL DEFAULT COMMENT 用户名,password varchar(50) NOT NULL DEFAULT COMMENT 用户密码&#xff0c;MD5加密…

Java面试题及答案整理( 2025年最新版,持续更新...)

最近发现网上很多Java面试题都没有答案&#xff0c;所以花了很长时间搜集整理出来了这套Java面试题大全&#xff0c;希望大家能够喜欢&#xff01; 注&#xff1a;篇幅有限&#xff0c;资料已整理成文档&#xff0c;后台si我666&#xff0c;我一个个发&#xff01; 这套面试文…

[论文阅读]PPT: Backdoor Attacks on Pre-trained Models via Poisoned Prompt Tuning

PPT: Backdoor Attacks on Pre-trained Models via Poisoned Prompt Tuning PPT: Backdoor Attacks on Pre-trained Models via Poisoned Prompt Tuning | IJCAI IJCAI-22 发表于2022年的论文&#xff0c;当时大家还都在做小模型NLP的相关工作&#xff08;BERT&#xff0c;Ro…

Redis最佳实践——性能优化技巧之集群与分片

Redis集群与分片在电商应用中的性能优化技巧 一、Redis集群架构模式解析 1. 主流集群方案对比 方案核心原理适用场景电商应用案例主从复制读写分离数据冗余中小规模读多写少商品详情缓存Redis Sentinel自动故障转移监控高可用需求场景订单状态缓存Redis Cluster原生分布式分片…

Vue 生命周期全解析:从创建到销毁的完整旅程

Vue 生命周期是每个 Vue 开发者必须深入理解的核心概念之一。它定义了组件从创建、挂载、更新、销毁的整个过程&#xff0c;以及在这个过程中各个阶段提供的钩子函数。掌握生命周期不仅能帮助你理解 Vue 的工作原理&#xff0c;还能让你在合适的时机执行特定的操作&#xff0c;…

【Rust 高级trait】Rust trait的一些高级用法解密

✨✨ 欢迎大家来到景天科技苑✨✨ &#x1f388;&#x1f388; 养成好习惯&#xff0c;先赞后看哦~&#x1f388;&#x1f388; &#x1f3c6; 作者简介&#xff1a;景天科技苑 &#x1f3c6;《头衔》&#xff1a;大厂架构师&#xff0c;华为云开发者社区专家博主&#xff0c;…

联想电脑护眼卫士与系统颜色配置(X-Rite)冲突 | 显示设置频繁变换色阶 - 解决方案

联想电脑护眼卫士与系统颜色配置X-Rite冲突 | 显示设置频繁变换色阶 - 解决方案 前言方案1&#xff1a;解决联想护眼卫士方案2&#xff1a;解决系统颜色配置(X-Rite) 前言 自带X-Rite软件的联想电脑&#xff08;以拯救者Y9000P&#xff0c;Win11系统为例&#xff09;&#xff…

MySQL中SELECT查询的执行顺序

MySQL中SELECT查询的执行顺序 在日常的数据库开发中&#xff0c;我们经常会写各种复杂的SELECT查询语句。然而&#xff0c;很多开发者对于MySQL实际执行这些查询的顺序并不完全了解。理解查询的执行顺序不仅有助于编写更高效的SQL语句&#xff0c;还能帮助我们更好地优化查询性…

es 的字段类型(text和keyword)

Text 当一个字段是要被全文检索时&#xff0c;比如 Email 内容、产品描述&#xff0c;这些字段应该使用 text 类型。设置 text 类型以后&#xff0c;字段内容会被分析&#xff0c;在生成倒排索引之前&#xff0c;字符串会被分析器分词。text类型的字段不用于排序&#xff0c;很…

MySQL安装及启用详细教程(Windows版)

MySQL安装及启用详细教程&#xff08;Windows版&#xff09; &#x1f4cb; 概述 本文档将详细介绍MySQL数据库在Windows系统下的下载、安装、配置和启用过程。 &#x1f4e5; MySQL下载 官方下载地址 官方网站: https://dev.mysql.com/downloads/社区版本: https://dev.my…

Linux下使用nmcli连接网络

Linux下使用nmcli连接网络 介绍 在使用ubuntu系统的时候&#xff0c;有时候不方便使用桌面&#xff0c;使用ssh远程连接&#xff0c;可能需要使用nmcli命令来连接网络。本文将介绍如何使用nmcli命令连接网络。nmcli 是 NetworkManager 的命令行工具&#xff0c;用于管理网络连…

Python----循环神经网络(BiLSTM:双向长短时记忆网络)

一、LSTM 与 BiLSTM对比 1.1、LSTM LSTM&#xff08;长短期记忆网络&#xff09; 是一种改进的循环神经网络&#xff08;RNN&#xff09;&#xff0c;专门解决传统RNN难以学习长期依赖的问题。它通过遗忘门、输入门和输出门来控制信息的流动&#xff0c;保留重要信息并丢弃无关…

U盘挂载Linux

在 只能使用 Telnet 的情况下&#xff0c;如果希望通过 U盘 传输文件到 Linux 系统&#xff0c;可以按照以下步骤操作&#xff1a; &#x1f4cc; 前提条件 U盘已插入 Linux 主机的 USB 接口。Linux 主机支持自动挂载 U盘&#xff08;大多数现代发行版默认支持&#xff09;。T…

QuickBASIC QB64 支持 64 位系统和跨平台Linux/MAC OS

QuickBASIC 的现代继任者 QB64 已发展成为一个功能强大的开源项目&#xff0c;支持 64 位系统和跨平台开发。以下是详细介绍&#xff1a; 项目首页 - QB64pe:The QB64 Phoenix Edition Repository - GitCode https://gitcode.com/gh_mirrors/qb/QB64pe 1. QB64 概述 官网&am…

【C++高级主题】命令空间(五):类、命名空间和作用域

目录 一、实参相关的查找&#xff08;ADL&#xff09;&#xff1a;函数调用的 “智能搜索” 1.1 ADL 的核心规则 1.2 ADL 的触发条件 1.3 ADL 的典型应用场景 1.4 ADL 的潜在风险与规避 二、隐式友元声明&#xff1a;类与命名空间的 “私密通道” 2.1 友元声明的基本规则…

免费开源Umi-OCR,离线使用,批量精准!

Umi-OCR&#xff08;Windows端&#xff09; Umi-OCR 是一款在 GitHub 上开源的免费 OCR 识别软件&#xff0c;它最大的亮点就是免费、开源、支持批量处理&#xff0c;而且识别准确度很高。这款软件不需要联网就能用&#xff0c;非常值得推荐&#xff01; 在 OCR 识别功能方面&…

深入剖析 Docker 容器化原理与实战应用,开启技术新征程!

文章目录 前言一、为什么 是Docker &#xff1f;二、Docker 容器化原理分析2.1 镜像&#xff08;Image&#xff09;2.2 容器&#xff08;Container&#xff09;2.3 仓库&#xff08;Registry&#xff09; 三、Docker 容器化实践3.1 Docker安装3.2 创建一个 Docker 镜像3.3 运行…

黑马程序员TypeScript课程笔记—class篇

class的基本使用 class的构造函数&#xff08;实现实例属性的初始化&#xff09; 在使用构造函数的时候&#xff0c;小括号的后面不要指定类型&#xff0c;否则就会报错&#xff0c;因为构造函数没有返回值 class实例方法 class继承&#xff08;extends&#xff09; class继承…