基于LSTM的股票价格预测

image-20240702100614494

摘要

本课设旨在利用LSTM(长短期记忆)网络实现股票价格预测,通过收集、预处理股票数据集,并构建预测模型进行训练与优化。实验结果显示,经过优化调整模型参数,模型在测试集上取得了较为理想的预测效果。尽管存在部分预测不准确的情况,总体而言,该模型在股票价格预测任务中表现良好,具有实际应用的潜力和效果。

导入必要的库

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

加载数据集

该项目所用数据来自飞浆开源数据集,数据集采用的是上证指数的股票
数据集包含10列 (股票来源、日期、开盘价、收盘价、最低价、最高价、交易量、交易额、跌涨幅、后一天最高价) ,共有6109天的股票数据
该项目中,我们利用历史数据中的开盘价、收盘价、最低价、最高价、交易量、交易额、跌涨幅来对下一日的最高价进行预测

数据集中所有数据都来自同一支股票,同时按照时间顺序排列好, 从1990年12月20日到2015年12月10日, 共6106条数据

#导入数据
data=pd.read_csv(r'.\datasets\stock_dataset.csv')
df=pd.DataFrame(data)
dataset = df.iloc[:,2:].to_numpy()
df.head()

image-20240702100842471

股票价格走势图像 这里我们对股票的每日的最高价格进行显示

df=pd.DataFrame(data,columns=['high'])
plt.plot(df)
plt.show()

image-20240702100922265

数据预处理

数据预处理, 这里因为所有数据都是存在的,所以不用再检查缺失值

# 3.1 得到训练数据与对应label
X = np.array(dataset[:,:-1])
y = np.array(dataset[:,-1])
# 3.2 标准化处理,归一化
st = StandardScaler()
X = st.fit_transform(X)
y = y / 1000
# 3.3 划分训练集和测试集 按照9:1的概率划分
X_train = X[0:int(len(X)*0.9),:]
y_train = y[0:int(len(y)*0.9)]
X_test = X[int(len(X)*0.9):,:]
y_test = y[int(len(y)*0.9):]
# 3.4 定义 PyTorch Dataset 类
class MyDataset(Dataset):
    def __init__(self, x, y, sequence_length):
        self.x = x
        self.y = y
        self.sequence_length = sequence_length

    def __len__(self):
        return len(self.x) - self.sequence_length

    def __getitem__(self, idx):
        return (
            torch.tensor(self.x[idx:idx+self.sequence_length], dtype=torch.float),
            torch.tensor(self.y[idx+self.sequence_length], dtype=torch.float),
        )
# 3.5 根据划分的训练集测试集生成需要的时间序列样本数据, 预测长度定为14,及根据前13天数据 预测后一天数据
sequence_length = 14
dataset_train = MyDataset(X_train, y_train, sequence_length)
train_dataloader = DataLoader(dataset_train, batch_size=64, shuffle=True)
dataset_test = MyDataset(X_test, y_test, sequence_length)
test_dataloader = DataLoader(dataset_test, batch_size=64, shuffle=False)  # 不需要打乱测试集

搭建模型

# 4.1 LSTM 模型
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.init_weights()

    def forward(self, x):
        out, _ = self.lstm(x)
        out = self.fc(out[:, -1, :])  # 使用最后一个时间步的输出
        return out

    def init_weights(self):
        # 设置随机数种子
        torch.manual_seed(42)

        # 遍历 LSTM 层的参数,对参数进行初始化
        for name, param in self.named_parameters():
            if 'weight' in name:
                nn.init.normal_(param, mean=0, std=0.1)  # 使用正态分布初始化权重
            elif 'bias' in name:
                nn.init.constant_(param, 0)  # 将偏置项初始化为零

训练模型

# 5.1 初始化模型、损失函数和优化器# 初始化最佳模型参数和最佳验证损失
best_model_params = model.state_dict()
best_val_loss = float('inf')
# 5.2 训练模型
num_epochs = 50
train_loss_list = []
val_loss_list = []
for epoch in range(num_epochs):
    train_loss = 0
    for batch_input, batch_target in train_dataloader:
        optimizer.zero_grad()
        output = model(batch_input)
        loss = criterion(output.squeeze(), batch_target)
        train_loss+=loss
        loss.backward()
        optimizer.step()
    train_loss = train_loss/len(train_dataloader)
    train_loss_list.append(train_loss.item())

    # 在验证集上计算损失并保存最佳模型
    with torch.no_grad():
        val_losses = []
        for val_batch_input, val_batch_target in test_dataloader:
            val_output = model(val_batch_input)
            val_loss = criterion(val_output.squeeze(), val_batch_target)
            val_losses.append(val_loss.item())
        avg_val_loss = np.mean(val_losses)
        val_loss_list.append(avg_val_loss)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model_params = model.state_dict()
            torch.save(best_model_params, './best_model_LSTM.pth', _use_new_zipfile_serialization=False)

    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss.item():.4f}, Validation Loss: {avg_val_loss:.4f}')

image-20240702101215341

# 5.3 绘制损失曲线
plt.title("Training loss curve")
plt.plot(train_loss_list)
plt.xlabel("Epoch")
plt.ylabel("Loss value")
plt.show()

image-20240702101233993

模型推理

# 6.1 加载模型
model_lstm = LSTMModel(input_size, hidden_size, num_layers, output_size)
model_lstm.load_state_dict(torch.load('./best_model_LSTM.pth'))

# 6.2 将数据集重新按照长度为14的序列进行划分, 以此划分方式,前sequence_length-1个的结果不会被预测,所以真实值中也应去除
res = []
for idx in range(0, len(X)- sequence_length):
    res.append(X[idx:idx+sequence_length])
res = torch.stack(res, dim=0)

# 6.3 开始推理
with torch.no_grad():
    val_output = model_lstm(res)
# 6.4 将推理结果 与 实际结果绘图进行比较
plt.plot(val_output*1000, label='predict')
plt.plot(y[sequence_length:]*1000, label='real')
plt.xlabel("date")
plt.ylabel("value")
plt.legend()
plt.show()

image-20240702101315885

完整项目

数据集、代码、报告

https://mbd.pub/o/bread/Zpeck5lt

精品学习专栏导航帖

机器学习项目实战 项目详解 + 数据集 + 完整源码+ 项目报告

基于YOLO的目标检测系统(PyQT页面+模型+数据集)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/780443.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

《征服数据结构》SparseArray

摘要&#xff1a; 1&#xff0c;SparseArray的介绍 2&#xff0c;SparseArray的代码实现 1&#xff0c;SparseArray的介绍 前面我们讲过《ArrayMap》&#xff0c;用它来实现哈希表&#xff0c;其中存放key和value的数组长度是存放散列表数组长度的二倍。 在哈希表中如果key值是…

SwiftData 模型对象的多个实例在 SwiftUI 中不能及时同步的解决

概览 我们已经知道,用 CoreData 在背后默默支持的 SwiftUI 视图在使用 @FetchRequest 来查询托管对象集合时,若查询结果中的托管对象在别处被改变将不会在 FetchedResults 中得到及时的刷新。 那么这一“囧境”在 SwiftData 里是否也会“卷土重来”呢?空说无益,就让我们在…

【项目设计】负载均衡式——Online Judge

负载均衡式——Online Judge&#x1f60e; 前言&#x1f64c;Online Judge 项目一、项目介绍二、项目技术栈三、项目使用环境四、项目宏观框架五、项目后端服务实现过程1、comm模块设计1.1 Log.hpp实现1.2 Util.hpp实现 2、compiler_server 模块设计2.1compile.hpp文件代码编写…

vb.netcad二开自学笔记2:认识vs编辑器

认识一下宇宙第一编辑器的界面图标含义还是很重要的&#xff0c;否则都不知道面对的是什么还怎么继续&#xff1f; 一、VS编辑器中常见的图标的含义 变量 长方体&#xff1a;变量 局部变量 两个矩形块&#xff1a;枚举 预定义的枚举 紫色立方体&#xff1a;方法 橙色树状结构…

通过AIS实现船舶追踪与照射

前些天突然接到个紧急的项目&#xff1a;某处需要实现对夜航船只进行追踪并用激光灯照射以保障夜航安全。这个项目紧急到什么程度呢&#xff1f;&#xff01;现场激光灯都安装好了&#xff0c;还有三个星期就要验收了&#xff0c;但上家没搞定就甩给我们了:( 从技术上看&#…

Java -- 实现MD5加密/加盐

目录 1. 加密的引出2. MD5介绍3. 解决MD5不可解密方法4. 实现加密解密4.1 加密4.2 验证密码 1. 加密的引出 在MySQL数据库中&#xff0c;一般都需要把密码、身份证、电话号码等信息进行加密&#xff0c;以确保数据的安全性。如果使用明文来存储&#xff0c;当数据库被入侵的时…

力扣考研经典题 反转链表

核心思想 头插法&#xff1a; 不断的将cur指针所指向的节点放到头节点之前&#xff0c;然后头节点指向cur节点&#xff0c;因为最后返回的是head.next 。 解题思路 1.如果头节点是空的&#xff0c;或者是只有一个节点&#xff0c;只需要返回head节点即可。 if (head null …

Vatee万腾平台:创新科技,驱动未来

在科技日新月异的今天&#xff0c;每一个创新的火花都可能成为推动社会进步的重要力量。Vatee万腾平台&#xff0c;作为科技创新领域的佼佼者&#xff0c;正以其卓越的技术实力、前瞻性的战略眼光和不懈的探索精神&#xff0c;驱动着未来的车轮滚滚向前。 Vatee万腾平台深知&am…

公有链、私有链与联盟链:区块链技术的多元化应用与比较

引言 区块链技术自2008年比特币白皮书发布以来&#xff0c;迅速发展成为一项具有颠覆性潜力的技术。区块链通过去中心化、不可篡改和透明的方式&#xff0c;提供了一种全新的数据存储和管理方式。起初&#xff0c;区块链主要应用于加密货币&#xff0c;如比特币和以太坊。然而&…

RUST 编程语言 绘制随机颜色图片 画圆形 画矩形 画直线

什么是Rust Rust是一种系统编程语言&#xff0c;旨在提供高性能和安全性。它是由Mozilla和其开发社区创建的开源语言&#xff0c;设计目标是在C的应用场景中提供一种现代、可靠和高效的选择。Rust的目标是成为一种通用编程语言&#xff0c;能够处理各种计算任务&#xff0c;包…

STM32-OC输出比较和PWM

本内容基于江协科技STM32视频内容&#xff0c;整理而得。 文章目录 1. OC输出比较和PWM1.1 OC输出比较1.2 PWM&#xff08;脉冲宽度调制&#xff09;1.3 输出比较通道&#xff08;高级&#xff09;1.4 输出比较通道&#xff08;通用&#xff09;1.5 输出比较模式1.6 PWM基本结…

数据库系统原理 | 查询作业2

整理自博主本科《数据库系统原理》专业课自己完成的实验课查询作业&#xff0c;以便各位学习数据库系统概论的小伙伴们参考、学习。 *文中若存在书写不合理的地方&#xff0c;欢迎各位斧正。 专业课本&#xff1a; ​ ​ ———— 本次实验使用到的图形化工具&#xff1a;Heidi…

ThreadPoolExecutor - 管理线程池的核心类

下面是使用给定的初始参数创建一个新的 ThreadPoolExecutor &#xff08;构造方法&#xff09;。 public ThreadPoolExecutor(int corePoolSize,int maximumPoolSize,long keepAliveTime,TimeUnit unit,BlockingQueue<Runnable> workQueue,ThreadFactory threadFactory,…

【SVN的使用-源代码管理工具-SVN介绍-服务器的搭建 Objective-C语言】

一、首先,我们来介绍一下源代码管理工具 1.源代码管理工具的起源 为什么会出现源代码管理工具,是为了解决源代码开发的过程中出现的很多问题: 1)无法后悔:把项目关了,无法Command + Z后悔, 2)版本备份:非空间、费时间、写的名称最后自己都忘了干什么的了, 3)版本…

中英双语介绍加拿大(Canada)

加拿大国家简介 中文版 加拿大简介 加拿大是位于北美洲北部的一个国家&#xff0c;以其广袤的土地、多样的文化和自然美景著称。以下是对加拿大的详细介绍&#xff0c;包括其地理位置、人口、经济、特色、高等教育、著名景点、国家历史和交通条件。 地理位置 加拿大是世界…

LeetCode 189.轮转数组 三段逆置 C写法

LeetCode 189.轮转数组 C写法 三段逆置 思路: 三段逆置方法:先逆置前n-k个 再逆置后k个 最后整体逆置 由示例1得&#xff0c;需要先逆置1,2,3,4 再逆置5,6,7&#xff0c;最后前n-k个与后k个逆置 代码 void reverse(int*num, int left, int right) //逆置函数 { while(left …

【工具】豆瓣自动回贴软件

转载请注明出处&#xff1a;小锋学长生活大爆炸[xfxuezhagn.cn] 如果本文帮助到了你&#xff0c;欢迎[点赞、收藏、关注]哦~ 相比于之前粗糙丑陋的黑命令框版本&#xff0c;这个版本新增了UI界面&#xff0c;从此可以不需要再挨个去翻配置文件了。 另外&#xff0c;升级了隐藏浏…

深入理解并发、线程与等待通知机制

目录 一、基础概念 进程和线程 进程 线程 Java 线程的无处不在 进程间的通信 进程间通信有几种方式&#xff1f; CPU 核心数和线程数的关系 上下文切换&#xff08;Context switch&#xff09; 并行和并发 二、认识 Java 里的线程 Java 程序天生就是多线程的 线程的…

使用Keil将STM32部分程序放在RAM中运行

手动分配RAM区域,新建.sct文件,定义RAM_CODE区域,并指定其正确的起始地址和大小。 ; ************************************************************* ; *** Scatter-Loading Description File generated by uVision *** ; ************************************************…

鸿蒙应用笔记

安装就跳过了&#xff0c;一直点点就可以了 配置跳过&#xff0c;就自动下了点东西。 鸿蒙那个下载要12g个内存&#xff0c;大的有点吓人。 里面跟idea没区别 模拟器或者真机运行 真机要鸿蒙4.0&#xff0c;就可以实机调试 直接在手机里面跑&#xff0c;这个牛逼&#xf…