- 1.ShuffleNetV1文章
- 2.重新搭建ShuffleNetV1模型结构(便于训练)
- (1)关于数据集和训练的整个代码可以参考我这篇文章
- (2)改换代码中的网络结构
- 3.训练结果
- 4.图像测试
https://mydreamambitious.blog.csdn.net/article/details/124675932
2.重新搭建ShuffleNetV1模型结构(便于训练) (1)关于数据集和训练的整个代码可以参考我这篇文章
https://mydreamambitious.blog.csdn.net/article/details/123966676
(2)改换代码中的网络结构注:只要将代码中的网络结构换成下面的这个ShuffleNetV1结构即可(如果采用的是面向对象的方式设计的网络结构的话,在训练数据的时候需要做一些调整)并且再将代码中的一些名称修改一下即可。
需要修改的名称如下:
# 通道重排
def Shuffle_Channels(inputs, groups):
# 获取图片batch,高度,宽度,通道数
batch, h, w, c = inputs.shape
# 对通道进行变形
input_reshape = tf.reshape(inputs, [-1, h, w, groups, c // groups])
# 退通道进行转置
input_transposed = tf.transpose(input_reshape, [0, 1, 2, 4, 3])
# 再对通道进行变形
output_reshape = tf.reshape(input_transposed, [-1, h, w, c])
return output_reshape
#分组卷积
def group_conv(inputs,filters,kernel,strides,groups):
#对通道进行切分
conv_num_layers=tf.split(inputs,groups,axis=3)
conv_side_layers=[]
for layer in conv_num_layers:
conv_side_layers.append(
layers.Conv2D(filters//groups,kernel_size=kernel,strides=strides,padding='same')(layer)
)
x_out=tf.concat(
conv_side_layers,axis=-1
)
return x_out
def ShuffleNetUnitB(inputs, filter, groups):
x=group_conv(inputs,filters=filter,kernel=[1,1],strides=[1,1],groups=groups)
# x = layers.Conv2D(filter // 4, kernel_size=[1, 1], strides=[1, 1], padding='same')(inputs)
x = layers.BatchNormalization()(x)
x = layers.Activation('relu')(x)
x = Shuffle_Channels(x, groups)
x = layers.DepthwiseConv2D(kernel_size=[3, 3], strides=[1, 1], padding='same')(x)
x = layers.BatchNormalization()(x)
x = group_conv(x, filters=filter, kernel=[1, 1], strides=[1, 1], groups=groups)
# x = layers.Conv2D(filter, kernel_size=[1, 1], strides=[1, 1], groups=groups, padding='same')(x)
x = layers.BatchNormalization()(x)
x_out = tf.add(inputs, x)
x_out = layers.Activation('relu')(x_out)
return x_out
def ShuffleNetUnitC(inputs, filter, inputFilter, groups):
outputChannels = filter - inputFilter
x = group_conv(inputs, filters=filter, kernel=[1, 1], strides=[1, 1], groups=groups)
# x = layers.Conv2D(filter // 4, kernel_size=[1, 1], strides=[1, 1], padding='same')(inputs)
x = layers.BatchNormalization()(x)
x = layers.Activation('relu')(x)
x = Shuffle_Channels(x, groups)
x = layers.DepthwiseConv2D(kernel_size=[3, 3], strides=[2, 2], padding='same')(x)
x = layers.BatchNormalization()(x)
x = group_conv(x, filters=outputChannels, kernel=[1, 1], strides=[1, 1], groups=groups)
# x = layers.Conv2D(outputChannels, kernel_size=[1, 1], strides=[1, 1], padding='same')(x)
x = layers.BatchNormalization()(x)
x_avgpool3x3 = layers.AveragePooling2D(pool_size=[3, 3], strides=[2, 2], padding='same')(inputs)
# 对channel进行堆叠
x_out = tf.concat([x, x_avgpool3x3],
axis=3)
x_out = layers.Activation('relu')(x_out)
return x_out
def ShuffleNetV1Model(numlayers, stage2Filter, stage3Filter, stage4Filter, groups,input_shape=(224,224,3)):
input = keras.Input(input_shape)
x = layers.Conv2D(24, kernel_size=[3, 3], strides=[2, 2], padding='same')(input)
x = layers.BatchNormalization()(x)
x = layers.Activation('relu')(x)
x = layers.MaxPool2D(pool_size=[3, 3], strides=[2, 2], padding='same')(x)
def StageX(inputs, numlayers, filter, inputFilter,groups):
x_out=ShuffleNetUnitC(inputs, filter, inputFilter, groups)
for i in range(numlayers):
x_out=ShuffleNetUnitB(x_out, filter, groups)
return x_out
x = StageX(x, numlayers[0], stage2Filter, inputFilter=24, groups=groups)
x = StageX(x, numlayers[1], stage3Filter, inputFilter=240, groups=groups)
x = StageX(x, numlayers[2], stage4Filter, inputFilter=480, groups=groups)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(2)(x)
x_out = layers.Activation('softmax')(x)
model = Model(inputs=input, outputs=x_out)
return model
model_shufflenetv1 = ShuffleNetV1Model(numlayers=[3, 7, 3], stage2Filter=240, stage3Filter=480,
stage4Filter=960, groups=3,input_shape=(224, 224, 3))
model_shufflenetv1.summary()
……
3.训练结果
由于数据集不是很多,只有近1000张图片,并且之包含两个类别(猫-cat,狗-dot)所以训练的最后效果并不怎么样(中间还需要调参),读者可以根据自己的需要收集图片进行训练。
只需要将flask中加载的权重文件换成自己训练的就可以了(flask这部分代码在从Github下载的文件中)。
https://github.com/KeepTryingTo/-.git
flask.py文件:
前端文件:index.html
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)