博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
gluon 实现多层感知机MLP分类FashionMNIST
阅读量:4963 次
发布时间:2019-06-12

本文共 1970 字,大约阅读时间需要 6 分钟。

from mxnet import gluon,initfrom mxnet.gluon import loss as gloss, nnfrom mxnet.gluon import data as gdatafrom mxnet import nd,autogradimport gluonbook as gbimport sys# 读取数据# 读取数据mnist_train = gdata.vision.FashionMNIST(train=True)mnist_test = gdata.vision.FashionMNIST(train=False)batch_size = 256transformer = gdata.vision.transforms.ToTensor()if sys.platform.startswith('win'):    num_workers = 0else:    num_workers = 4# 小批量数据迭代器train_iter = gdata.DataLoader(mnist_train.transform_first(transformer),batch_size=batch_size,shuffle=True,num_workers=num_workers)test_iter = gdata.DataLoader(mnist_test.transform_first(transformer),batch_size=batch_size,shuffle=False,num_workers=num_workers)# 定义网络net = nn.Sequential()net.add(nn.Dense(256,activation='relu'),nn.Dense(10))net.initialize(init.Normal(sigma=0.01))# 损失函数loss = gloss.SoftmaxCrossEntropyLoss()trainer = gluon.Trainer(net.collect_params(),'sgd',{
'learning_rate':0.5})def accuracy(y_hat, y): return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar()def evaluate_accuracy(data_iter, net): acc = 0 for X, y in data_iter: acc += accuracy(net(X), y) return acc / len(data_iter)num_epochs = 5def train(net,train_iter,test_iter,loss,num_epochs,batch_size,params=None,lr=None,trainer=None): for epoch in range(num_epochs): train_l_sum = 0 train_acc_sum = 0 for X,y in train_iter: with autograd.record(): y_hat = net(X) l = loss(y_hat,y) l.backward() if trainer is None: gb.sgd(params,lr,batch_size) else: trainer.step(batch_size) train_l_sum += l.mean().asscalar() test_acc = evaluate_accuracy(test_iter,net) print('epoch %d,loss %.4f,test acc %.3f'%(epoch+1,train_l_sum / len(train_iter),test_acc))train(net,train_iter,test_iter,loss,num_epochs,batch_size,None,None,trainer)

转载于:https://www.cnblogs.com/TreeDream/p/10033557.html

你可能感兴趣的文章
flask简单的注册功能
查看>>
JSP常用标签
查看>>
dashucoding记录2019.6.7
查看>>
IOS FMDB
查看>>
编码总结,以及对BOM的理解
查看>>
九涯的第一次
查看>>
PHP5.3的VC9、VC6、Thread Safe、Non Thread Safe的区别
查看>>
Android中全屏或者取消标题栏
查看>>
处理器管理与进程调度
查看>>
页面懒加载
查看>>
向量非零元素个数_向量范数详解+代码实现
查看>>
java zip 中文文件名乱码_java使用zip压缩中文文件名乱码的解决办法
查看>>
java if 用法详解_Java编程中的条件判断之if语句的用法详解
查看>>
kafka的java客户端_KAFKA Producer java客户端示例
查看>>
java -f_java学习笔记(一)
查看>>
java 什么题目好做_用java做这些题目
查看>>
java中的合同打印_比较方法违反了Java 7中的一般合同
查看>>
php 位运算与权限,怎么在PHP中使用位运算对网站的权限进行管理
查看>>
php include效率,php include类文件超时
查看>>
matlab sin函数 fft,matlab的fft函数的使用教程
查看>>