nn---Neural network torch.nn里面有神经网络的一些工具
torch.nn — PyTorch 1.11.0 documentationhttps://pytorch.org/docs/stable/nn.html
-
Containers 骨架,给神经网络定义了结构,往这些结构之中添加不同的内容就可以构成我们的神经网络
-
下面就是可以往神经网络里添加的一些 *** 作
-
Convolution Layers 卷积层
-
Pooling layers 池化层
-
Padding Layers
-
Non-linear Activations (weighted sum, nonlinearity) 非线性激活
-
Non-linear Activations (other)
-
Normalization Layers 正则化层
-
Recurrent Layers
-
Transformer Layers
-
Linear Layers
-
Dropout Layers
-
Sparse Layers
-
Distance Functions
-
Loss Functions
-
Vision Layers
-
Shuffle Layers
-
DataParallel Layers (multi-GPU, distributed)
-
Utilities
-
Quantized Functions
-
Lazy Modules Initialization
基本构成模块
Module | Base class for all neural network modules. 常用 对于所有神经网络一个基本的类 |
Sequential | A sequential container. |
ModuleList | Holds submodules in a list. |
ModuleDict | Holds submodules in a dictionary. |
ParameterList | Holds parameters in a list. |
ParameterDict | Holds parameters in a dictionary. |
Module简单的使用
class Tudui(nn.Module):#继承nn.Module 注意 Module 首字母大写!!!
#初始化
def __init__(self) -> None:#可以直接写 也可以alt+insert选择重写__init__
super().__init__()
#forward 定义每次调用时执行的计算,应该被所有子类重写
def forward(self,input):
output=input+1
return output
tudui=Tudui()
x=torch.tensor(1.0)
output=tudui(x)
print(output)
2.Convolution Layers 卷积层
nn.Conv1d | Applies a 1D convolution over an input signal composed of several input planes. 一维卷积 |
nn.Conv2d | Applies a 2D convolution over an input signal composed of several input planes.二维卷积 |
Parameters
(2)演示TORCH.NN.FUNCTIONAL.CONV2
演示此过程
import torch
import torch.nn.functional as F
input=torch.tensor([[1,2,0,3,1],
[0,1,2,3,1],
[1,2,1,0,0],
[5,2,3,1,1],
[2,1,0,1,1]])#二维矩阵,看中括号数
#卷积核
kernel=torch.tensor([[1,2,1],
[0,1,0],
[2,1,0]])
#conv2d 要求输入四个数据,我们只有两个高和宽,需要变换尺寸
input=torch.reshape(input,(1,1,5,5))#通道数是1,batch的大小是1,数据维度是5*5
kernel=torch.reshape(kernel,(1,1,3,3))
print(input.shape)
print(kernel.shape)
output=F.conv2d(input,kernel,stride=1)
print(output)
output2=F.conv2d(input,kernel,stride=2)
print(output2)
output3=F.conv2d(input,kernel,stride=1,padding=1)#padding=1
print(output3)
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)