信息发布→ 登录 注册 退出

PyTorch实现线性回归详细过程

发布时间:2026-01-11

点击量:
目录
  • 一、实现步骤
    • 1、准备数据
    • 2、设计模型
    • 3、构造损失函数和优化器
    • 4、训练过程
    • 5、结果展示
  • 二、参考文献

    一、实现步骤

    1、准备数据

    x_data = torch.tensor([[1.0],[2.0],[3.0]])
    y_data = torch.tensor([[2.0],[4.0],[6.0]])

    2、设计模型

    class LinearModel(torch.nn.Module):
        def __init__(self):
            super(LinearModel,self).__init__()
            self.linear = torch.nn.Linear(1,1)
            
        def forward(self, x):
            y_pred = self.linear(x)
            return y_pred
            
    model = LinearModel()  

    3、构造损失函数和优化器

    criterion = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(),lr=0.01)

    4、训练过程

    epoch_list = []
    loss_list = []
    w_list = []
    b_list = []
    for epoch in range(1000):
        y_pred = model(x_data)                      # 计算预测值
        loss = criterion(y_pred, y_data)    # 计算损失
        print(epoch,loss)
        
        epoch_list.append(epoch)
        loss_list.append(loss.data.item())
        w_list.append(model.linear.weight.item())
        b_list.append(model.linear.bias.item())
        
        optimizer.zero_grad()   # 梯度归零
        loss.backward()         # 反向传播
        optimizer.step()        # 更新

    5、结果展示

    展示最终的权重和偏置:

    # 输出权重和偏置
    print('w = ',model.linear.weight.item())
    print('b = ',model.linear.bias.item())

    结果为:

    w =  1.9998501539230347
    b =  0.0003405189490877092

    模型测试:

    # 测试模型
    x_test = torch.tensor([[4.0]])
    y_test = model(x_test)
    print('y_pred = ',y_test.data)
    
    y_pred =  tensor([[7.9997]])

    分别绘制损失值随迭代次数变化的二维曲线图和其随权重与偏置变化的三维散点图:

    # 二维曲线图
    plt.plot(epoch_list,loss_list,'b')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.show()
    
    # 三维散点图
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(w_list,b_list,loss_list,c='r')
    #设置坐标轴
    ax.set_xlabel('weight')
    ax.set_ylabel('bias')
    ax.set_zlabel('loss')
    plt.show()

    结果如下图所示:

     到此这篇关于PyTorch实现线性回归详细过程的文章就介绍到这了,更多相关PyTorch线性回归内容请搜索以前的文章或继续浏览下面的相关文章希望大家以后多多支持!

    二、参考文献

    • [1] https://www.bilibili.com/video/BV1Y7411d7Ys?p=5
    在线客服
    服务热线

    服务热线

    4008888355

    微信咨询
    二维码
    返回顶部
    ×二维码

    截屏,微信识别二维码

    打开微信

    微信号已复制,请打开微信添加咨询详情!