pytorch简洁实现MNIST识别
可以看这里,比较与tensorflow版本的区别在输入格式方面,pytorch是NCHW,tensorflow是NHWC网络返回log_softmax时,应该使用nll_lossMyData.py
import os
import cv2
import random
import numpy as np
class Dataset(object):
def __init__(self, dataset_path, train, batch_size=1):
self.all = []
for line in open(dataset_path):
self.all.append(line)
if train == True:
random.shuffle(self.all)
self.bs
共有 0 条评论