用户登录
用户注册

分享至

d2l.train_ch3函数,将输入数据转化为该函数所接受的格式

  • 作者: 负面情绪奶豆
  • 来源: 51数据库
  • 2022-05-02

? 目的:使用d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None,None, trainer)这个函数进行softmax回归计算

? 问题:使用非官方的数据报错

? 解决方法:将输入数据转为函数要求的格式

? 初始数据:

?

? 原先的数据是列表,通过nd.array的方式将熟悉的列表数据转为张量

from mxnet import autograd, nd

data  = nd.array(data)

? 然后通过data_iter函数,生成随机的小批量样本

代码

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)  # 样本的读取顺序是随机的  indices列表随机排列
    for i in range(0, num_examples, batch_size):
        j = nd.array(indices[i: min(i + batch_size, num_examples)])
        yield features.take(j), labels.take(j)  # take函数根据索引返回对应元素  最后结果为生成器

train_iter = list(data_iter(batch_size, image_data_all[0:1000], image_label_all[0:1000]))
test_iter = list(data_iter(batch_size, image_data_all[1000:1800], image_label_all[1000:1800]))

? ?这时将数据带入仍会报错,通过list(data_iter(...))的方式将生成器转化为列表,其中列表里面有张量,通过for循环我们可以得到我们需要 的张量

num_epochs = 5
d2l.train_ch3(net,train_iter, test_iter,loss, num_epochs, batch_size, None,
              None, trainer)

结果:
epoch 1, loss 1.1367, train acc 0.997, test acc 0.839
epoch 2, loss 0.7242, train acc 0.999, test acc 0.845
epoch 3, loss 0.6524, train acc 0.999, test acc 0.841
epoch 4, loss 0.5581, train acc 0.999, test acc 0.839
epoch 5, loss 0.0000, train acc 1.000, test acc 0.839

? 成功!!!

软件
前端设计
程序设计
Java相关