Understanding the FastSpeech2 Code: Model Implementation (Part 1)


Contents
  • References
  • FastSpeech2
    • 1. FastSpeech2 implementation
    • 2. Model structure
  • Implementation of the FastSpeech2 modules
    • 1. Encoder
    • 2. Decoder
    • 3. Variance Adaptor


References

Reference project: GitHub implementation of FastSpeech2
FastSpeech2 paper
FastSpeech2 model code analysis

FastSpeech2

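FastSpeech2 is a Transformer-based end-to-end speech synthesis model with the following structure:

The Encoder converts the phoneme sequence into a hidden sequence. The Variance Adaptor then adds variance information such as duration, pitch, and energy to the hidden sequence. Finally, the decoder converts the hidden sequence into a mel-spectrogram sequence.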
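
As a rough illustration of this data flow, the sketch below tracks only the sequence shapes through the three stages. The function name trace_shapes and the example durations are hypothetical; the hidden size 256 and the 80 mel channels are the values that appear in the model printout later in this article.

```python
def trace_shapes(n_phonemes, durations, hidden=256, n_mels=80):
    """Return the (length, channels) shape after each stage for one utterance."""
    assert len(durations) == n_phonemes
    encoder_out = (n_phonemes, hidden)   # one hidden vector per phoneme
    n_frames = sum(durations)            # each phoneme is expanded to durations[i] frames
    adaptor_out = (n_frames, hidden)     # pitch/energy information is added, width unchanged
    mel_out = (n_frames, n_mels)         # decoder plus linear projection to mel channels
    return encoder_out, adaptor_out, mel_out


print(trace_shapes(3, [2, 4, 1]))  # ((3, 256), (7, 256), (7, 80))
```

The point to notice is that the sequence length changes only inside the Variance Adaptor, where phoneme-level vectors are expanded to frame-level vectors.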


1. FastSpeech2 implementation

The FastSpeech2 model is implemented in /model/fastspeech2.py:

class FastSpeech2(nn.Module):
    """ FastSpeech2 """

    def __init__(self, preprocess_config, model_config):
        super(FastSpeech2, self).__init__()
        self.model_config = model_config

        self.encoder = Encoder(model_config)
        self.variance_adaptor = VarianceAdaptor(preprocess_config, model_config)
        self.decoder = Decoder(model_config)
        self.mel_linear = nn.Linear(
            model_config["transformer"]["decoder_hidden"],
            preprocess_config["preprocessing"]["mel"]["n_mel_channels"],
        )
        self.postnet = PostNet()

        self.speaker_emb = None
        if model_config["multi_speaker"]:
            with open(
                os.path.join(
                    preprocess_config["path"]["preprocessed_path"], "speakers.json"
                ),
                "r",
            ) as f:
                n_speaker = len(json.load(f))
            self.speaker_emb = nn.Embedding(
                n_speaker,
                model_config["transformer"]["encoder_hidden"],
            )

2. Model structure

The following code prints the FastSpeech2 model structure of the reference project:

from model import FastSpeech2
from utils.tools import get_configs_of

preprocess_config, model_config, train_config = get_configs_of("AISHELL3")

print(FastSpeech2(preprocess_config, model_config))

get_configs_of does not exist in the original reference project; add the following code to /utils/tools.py:

import os

import yaml


def get_configs_of(dataset):
    """Load the preprocess, model, and train configs from ./config/<dataset>/."""
    config_dir = os.path.join("./config", dataset)
    configs = []
    for name in ("preprocess.yaml", "model.yaml", "train.yaml"):
        # Use a context manager so each file handle is closed after loading.
        with open(os.path.join(config_dir, name), "r") as f:
            configs.append(yaml.load(f, Loader=yaml.FullLoader))
    preprocess_config, model_config, train_config = configs
    return preprocess_config, model_config, train_config

The printed FastSpeech2 structure is shown below. The postnet module is not part of the FastSpeech2 paper; it was added by the author of the reference project:

FastSpeech2(
  (encoder): Encoder(
    (src_word_emb): Embedding(361, 256, padding_idx=0)
    (layer_stack): ModuleList(
      (0): FFTBlock(
        (slf_attn): MultiHeadAttention(
          (w_qs): Linear(in_features=256, out_features=256, bias=True)
          (w_ks): Linear(in_features=256, out_features=256, bias=True)
          (w_vs): Linear(in_features=256, out_features=256, bias=True)
          (attention): ScaledDotProductAttention(
            (softmax): Softmax(dim=2)
          )
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc): Linear(in_features=256, out_features=256, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (pos_ffn): PositionwiseFeedForward(
          (w_1): Conv1d(256, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
          (w_2): Conv1d(1024, 256, kernel_size=(1,), stride=(1,))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
      (1): FFTBlock(
        (slf_attn): MultiHeadAttention(
          (w_qs): Linear(in_features=256, out_features=256, bias=True)
          (w_ks): Linear(in_features=256, out_features=256, bias=True)
          (w_vs): Linear(in_features=256, out_features=256, bias=True)
          (attention): ScaledDotProductAttention(
            (softmax): Softmax(dim=2)
          )
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc): Linear(in_features=256, out_features=256, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (pos_ffn): PositionwiseFeedForward(
          (w_1): Conv1d(256, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
          (w_2): Conv1d(1024, 256, kernel_size=(1,), stride=(1,))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
      (2): FFTBlock(
        (slf_attn): MultiHeadAttention(
          (w_qs): Linear(in_features=256, out_features=256, bias=True)
          (w_ks): Linear(in_features=256, out_features=256, bias=True)
          (w_vs): Linear(in_features=256, out_features=256, bias=True)
          (attention): ScaledDotProductAttention(
            (softmax): Softmax(dim=2)
          )
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc): Linear(in_features=256, out_features=256, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (pos_ffn): PositionwiseFeedForward(
          (w_1): Conv1d(256, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
          (w_2): Conv1d(1024, 256, kernel_size=(1,), stride=(1,))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
      (3): FFTBlock(
        (slf_attn): MultiHeadAttention(
          (w_qs): Linear(in_features=256, out_features=256, bias=True)
          (w_ks): Linear(in_features=256, out_features=256, bias=True)
          (w_vs): Linear(in_features=256, out_features=256, bias=True)
          (attention): ScaledDotProductAttention(
            (softmax): Softmax(dim=2)
          )
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc): Linear(in_features=256, out_features=256, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (pos_ffn): PositionwiseFeedForward(
          (w_1): Conv1d(256, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
          (w_2): Conv1d(1024, 256, kernel_size=(1,), stride=(1,))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
  )
  (variance_adaptor): VarianceAdaptor(
    (duration_predictor): VariancePredictor(
      (conv_layer): Sequential(
        (conv1d_1): Conv(
          (conv): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
        )
        (relu_1): ReLU()
        (layer_norm_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout_1): Dropout(p=0.5, inplace=False)
        (conv1d_2): Conv(
          (conv): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
        )
        (relu_2): ReLU()
        (layer_norm_2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout_2): Dropout(p=0.5, inplace=False)
      )
      (linear_layer): Linear(in_features=256, out_features=1, bias=True)
    )
    (length_regulator): LengthRegulator()
    (pitch_predictor): VariancePredictor(
      (conv_layer): Sequential(
        (conv1d_1): Conv(
          (conv): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
        )
        (relu_1): ReLU()
        (layer_norm_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout_1): Dropout(p=0.5, inplace=False)
        (conv1d_2): Conv(
          (conv): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
        )
        (relu_2): ReLU()
        (layer_norm_2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout_2): Dropout(p=0.5, inplace=False)
      )
      (linear_layer): Linear(in_features=256, out_features=1, bias=True)
    )
    (energy_predictor): VariancePredictor(
      (conv_layer): Sequential(
        (conv1d_1): Conv(
          (conv): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
        )
        (relu_1): ReLU()
        (layer_norm_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout_1): Dropout(p=0.5, inplace=False)
        (conv1d_2): Conv(
          (conv): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
        )
        (relu_2): ReLU()
        (layer_norm_2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout_2): Dropout(p=0.5, inplace=False)
      )
      (linear_layer): Linear(in_features=256, out_features=1, bias=True)
    )
    (pitch_embedding): Embedding(256, 256)
    (energy_embedding): Embedding(256, 256)
  )
  (decoder): Decoder(
    (layer_stack): ModuleList(
      (0): FFTBlock(
        (slf_attn): MultiHeadAttention(
          (w_qs): Linear(in_features=256, out_features=256, bias=True)
          (w_ks): Linear(in_features=256, out_features=256, bias=True)
          (w_vs): Linear(in_features=256, out_features=256, bias=True)
          (attention): ScaledDotProductAttention(
            (softmax): Softmax(dim=2)
          )
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc): Linear(in_features=256, out_features=256, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (pos_ffn): PositionwiseFeedForward(
          (w_1): Conv1d(256, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
          (w_2): Conv1d(1024, 256, kernel_size=(1,), stride=(1,))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
      (1): FFTBlock(
        (slf_attn): MultiHeadAttention(
          (w_qs): Linear(in_features=256, out_features=256, bias=True)
          (w_ks): Linear(in_features=256, out_features=256, bias=True)
          (w_vs): Linear(in_features=256, out_features=256, bias=True)
          (attention): ScaledDotProductAttention(
            (softmax): Softmax(dim=2)
          )
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc): Linear(in_features=256, out_features=256, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (pos_ffn): PositionwiseFeedForward(
          (w_1): Conv1d(256, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
          (w_2): Conv1d(1024, 256, kernel_size=(1,), stride=(1,))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
      (2): FFTBlock(
        (slf_attn): MultiHeadAttention(
          (w_qs): Linear(in_features=256, out_features=256, bias=True)
          (w_ks): Linear(in_features=256, out_features=256, bias=True)
          (w_vs): Linear(in_features=256, out_features=256, bias=True)
          (attention): ScaledDotProductAttention(
            (softmax): Softmax(dim=2)
          )
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc): Linear(in_features=256, out_features=256, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (pos_ffn): PositionwiseFeedForward(
          (w_1): Conv1d(256, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
          (w_2): Conv1d(1024, 256, kernel_size=(1,), stride=(1,))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
      (3): FFTBlock(
        (slf_attn): MultiHeadAttention(
          (w_qs): Linear(in_features=256, out_features=256, bias=True)
          (w_ks): Linear(in_features=256, out_features=256, bias=True)
          (w_vs): Linear(in_features=256, out_features=256, bias=True)
          (attention): ScaledDotProductAttention(
            (softmax): Softmax(dim=2)
          )
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc): Linear(in_features=256, out_features=256, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (pos_ffn): PositionwiseFeedForward(
          (w_1): Conv1d(256, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
          (w_2): Conv1d(1024, 256, kernel_size=(1,), stride=(1,))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
      (4): FFTBlock(
        (slf_attn): MultiHeadAttention(
          (w_qs): Linear(in_features=256, out_features=256, bias=True)
          (w_ks): Linear(in_features=256, out_features=256, bias=True)
          (w_vs): Linear(in_features=256, out_features=256, bias=True)
          (attention): ScaledDotProductAttention(
            (softmax): Softmax(dim=2)
          )
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc): Linear(in_features=256, out_features=256, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (pos_ffn): PositionwiseFeedForward(
          (w_1): Conv1d(256, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
          (w_2): Conv1d(1024, 256, kernel_size=(1,), stride=(1,))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
      (5): FFTBlock(
        (slf_attn): MultiHeadAttention(
          (w_qs): Linear(in_features=256, out_features=256, bias=True)
          (w_ks): Linear(in_features=256, out_features=256, bias=True)
          (w_vs): Linear(in_features=256, out_features=256, bias=True)
          (attention): ScaledDotProductAttention(
            (softmax): Softmax(dim=2)
          )
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc): Linear(in_features=256, out_features=256, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (pos_ffn): PositionwiseFeedForward(
          (w_1): Conv1d(256, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
          (w_2): Conv1d(1024, 256, kernel_size=(1,), stride=(1,))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
  )
  (mel_linear): Linear(in_features=256, out_features=80, bias=True)
  (postnet): PostNet(
    (convolutions): ModuleList(
      (0): Sequential(
        (0): ConvNorm(
          (conv): Conv1d(80, 512, kernel_size=(5,), stride=(1,), padding=(2,))
        )
        (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): Sequential(
        (0): ConvNorm(
          (conv): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,))
        )
        (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (2): Sequential(
        (0): ConvNorm(
          (conv): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,))
        )
        (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (3): Sequential(
        (0): ConvNorm(
          (conv): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,))
        )
        (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (4): Sequential(
        (0): ConvNorm(
          (conv): Conv1d(512, 80, kernel_size=(5,), stride=(1,), padding=(2,))
        )
        (1): BatchNorm1d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (speaker_emb): Embedding(218, 256)
)
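
The layer shapes in this printout are enough to estimate the size of one FFTBlock by hand. The sketch below does that arithmetic in plain Python; the helper names are hypothetical, and it assumes every Linear and Conv1d layer carries a bias (the printout shows bias=True for the Linear layers and does not flag bias=False on the Conv1d layers).

```python
def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weight matrix + bias


def conv1d_params(c_in, c_out, kernel_size):
    return c_in * c_out * kernel_size + c_out  # kernel weights + bias


def layer_norm_params(dim):
    return 2 * dim  # scale and shift (elementwise_affine=True)


# slf_attn: w_qs, w_ks, w_vs, fc (all 256 -> 256) plus one LayerNorm(256)
attention = 4 * linear_params(256, 256) + layer_norm_params(256)
# pos_ffn: Conv1d(256, 1024, k=9), Conv1d(1024, 256, k=1) plus one LayerNorm(256)
feed_forward = (
    conv1d_params(256, 1024, 9) + conv1d_params(1024, 256, 1) + layer_norm_params(256)
)
fft_block = attention + feed_forward

print(attention, feed_forward, fft_block)  # 263680 2623232 2886912
print(10 * fft_block)                      # 4 encoder + 6 decoder blocks: 28869120
```

Under these assumptions, roughly 90% of each block's weights sit in the kernel-size-9 convolution of pos_ffn rather than in the attention layer.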

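Implementation of the FastSpeech2 modules

FastSpeech2 uses the feed-forward Transformer block (FFTBlock) as the basic building block of both the encoder and the mel-spectrogram decoder. An FFTBlock consists of a self-attention layer, slf_attn, and a position-wise feed-forward network, pos_ffn.


The FFTBlock is implemented in /transformer/Layers.py: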

class FFTBlock(torch.nn.Module):
    """FFT Block"""

    def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1):
        super(FFTBlock, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(
            d_model, d_inner, kernel_size, dropout=dropout
        )

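    def forward(self, enc_input, mask=None, slf_attn_mask=None):
        # Self-attention: query, key, and value are all enc_input.
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask
        )
        # Zero the masked positions; mask must be passed despite its None default.
        enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0)

        # Position-wise feed-forward network, followed by the same masking step.
        enc_output = self.pos_ffn(enc_output)
        enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0)

        return enc_output, enc_slf_attn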
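
The forward pass above zeroes the masked positions after each sub-layer. The following plain-Python sketch illustrates that masking step on nested lists instead of tensors. It assumes the mask has shape (batch, time) with True marking padded positions, which is inferred from how masked_fill is used above; padding_mask and masked_fill here are hypothetical stand-ins, not functions from the reference project.

```python
def padding_mask(lengths, max_len=None):
    """Build a (batch, time) mask where True marks a padded position."""
    if max_len is None:
        max_len = max(lengths)
    return [[t >= n for t in range(max_len)] for n in lengths]


def masked_fill(batch, mask, value=0.0):
    """Replace the feature vector at every masked position with `value`."""
    return [
        [[value] * len(vec) if is_pad else vec for vec, is_pad in zip(seq, seq_mask)]
        for seq, seq_mask in zip(batch, mask)
    ]


lengths = [3, 1]                       # true lengths of two sequences padded to 3 steps
mask = padding_mask(lengths)
batch = [[[1.0, 1.0]] * 3, [[1.0, 1.0]] * 3]
out = masked_fill(batch, mask)

print(mask)    # [[False, False, False], [False, True, True]]
print(out[1])  # [[1.0, 1.0], [0.0, 0.0], [0.0, 0.0]]
```

Re-applying the mask after both sub-layers keeps padded frames at zero, so values that the convolution in pos_ffn writes into padded positions do not carry over to the next block.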
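
1. Encoder

The Encoder consists of a word embedding, src_word_emb, followed by 4 FFTBlocks, together with a fixed positional encoding, position_enc. It is implemented in /transformer/Models.py.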


class Encoder(nn.Module):
    """ Encoder """

    def __init__(self, config):
        super(Encoder, self).__init__()

        n_position = config["max_seq_len"] + 1
        n_src_vocab = len(symbols) + 1
        d_word_vec = config["transformer"]["encoder_hidden"]
        n_layers = config["transformer"]["encoder_layer"]
        n_head = config["transformer"]["encoder_head"]
        d_k = d_v = (
            config["transformer"]["encoder_hidden"]
            // config["transformer"]["encoder_head"]
        )
        d_model = config["transformer"]["encoder_hidden"]
        d_inner = config["transformer"]["conv_filter_size"]
        kernel_size = config["transformer"]["conv_kernel_size"]
        dropout = config["transformer"]["encoder_dropout"]

        self.max_seq_len = config["max_seq_len"]
        self.d_model = d_model

        # Word (phoneme) embedding
        self.src_word_emb = nn.Embedding(
            n_src_vocab, d_word_vec, padding_idx=Constants.PAD
        )
        # Positional encoding (fixed, not trained)
        self.position_enc = nn.Parameter(
            get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0),
            requires_grad=False,
        )

        # Stack of n_layers FFTBlocks (4 in this configuration)
        self.layer_stack = nn.ModuleList(
            [
                FFTBlock(
                    d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout
                )
                for _ in range(n_layers)
            ]
        )
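
The position_enc parameter above is filled by get_sinusoid_encoding_table, whose body is not shown in this article. The sketch below assumes it follows the standard sinusoidal formulation from the original Transformer (sine on even dimensions, cosine on odd dimensions); sinusoid_table is a hypothetical plain-Python stand-in, not the reference project's function.

```python
import math


def sinusoid_table(n_position, d_hid):
    """Return an n_position x d_hid table of sinusoidal position encodings."""
    table = []
    for pos in range(n_position):
        row = []
        for i in range(d_hid):
            # Dimensions 2j and 2j+1 share the same frequency.
            angle = pos / (10000 ** (2 * (i // 2) / d_hid))
            row.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


table = sinusoid_table(n_position=5, d_hid=4)
print(table[0])  # position 0: [0.0, 1.0, 0.0, 1.0]
```

Because the table depends only on position and dimension, it can be computed once and stored with requires_grad=False, as the Encoder does.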

2. Decoder

The Decoder consists mainly of 6 FFTBlocks. It is implemented in /transformer/Models.py.


class Decoder(nn.Module):
    """ Decoder """

    def __init__(self, config):
        super(Decoder, self).__init__()

        n_position = config["max_seq_len"] + 1
        d_word_vec = config["transformer"]["decoder_hidden"]
        n_layers = config["transformer"]["decoder_layer"]
        n_head = config["transformer"]["decoder_head"]
        d_k = d_v = (
            config["transformer"]["decoder_hidden"]
            // config["transformer"]["decoder_head"]
        )
        d_model = config["transformer"]["decoder_hidden"]
        d_inner = config["transformer"]["conv_filter_size"]
        kernel_size = config["transformer"]["conv_kernel_size"]
        dropout = config["transformer"]["decoder_dropout"]

        self.max_seq_len = config["max_seq_len"]
        self.d_model = d_model

        self.position_enc = nn.Parameter(
            get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0),
            requires_grad=False,
        )

        # Stack of n_layers FFTBlocks (6 in this configuration)
        self.layer_stack = nn.ModuleList(
            [
                FFTBlock(
                    d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout
                )
                for _ in range(n_layers)
            ]
        )
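
3. Variance Adaptor

The structure of the Variance Adaptor is shown in Figure (b). It consists of a duration predictor (duration_predictor), a pitch predictor (pitch_predictor), and an energy predictor (energy_predictor). All three predictors share the same structure, shown in Figure (c).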
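
Besides the three predictors, the model printout also lists a LengthRegulator inside the Variance Adaptor. In FastSpeech-style models this step expands each phoneme's hidden vector to as many frames as its duration. The sketch below shows the idea on a single unbatched sequence; length_regulate is a hypothetical simplification and not the reference project's implementation.

```python
def length_regulate(hidden, durations):
    """Repeat hidden[i] durations[i] times; a duration of 0 drops that phoneme."""
    assert len(hidden) == len(durations)
    expanded = []
    for vec, d in zip(hidden, durations):
        expanded.extend([vec] * max(int(d), 0))
    return expanded


hidden = ["h_a", "h_b", "h_c"]  # placeholders for per-phoneme hidden vectors
print(length_regulate(hidden, [2, 0, 3]))  # ['h_a', 'h_a', 'h_c', 'h_c', 'h_c']
```

The output length equals the sum of the durations, which is what lets the decoder produce one mel frame per expanded position.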




The Variance Adaptor is implemented in /model/modules.py:

class VarianceAdaptor(nn.Module):
    """Variance Adaptor"""

    def __init__(self, preprocess_config, model_config):
        super(VarianceAdaptor, self).__init__()
        self.duration_predictor = VariancePredictor(model_config)
        self.length_regulator = LengthRegulator()
        self.pitch_predictor = VariancePredictor(model_config)
        self.energy_predictor = VariancePredictor(model_config)

        self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][
            "feature"
        ]
        self.energy_feature_level = preprocess_config["preprocessing"]["energy"][
            "feature"
        ]
        assert self.pitch_feature_level in ["phoneme_level", "frame_level"]
        assert self.energy_feature_level in ["phoneme_level", "frame_level"]

        pitch_quantization = model_config["variance_embedding"]["pitch_quantization"]
        energy_quantization = model_config["variance_embedding"]["energy_quantization"]
        n_bins = model_config["variance_embedding"]["n_bins"]
        assert pitch_quantization in ["linear", "log"]
        assert energy_quantization in ["linear", "log"]
        with open(
            os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
        ) as f:
            stats = json.load(f)
            pitch_min, pitch_max = stats["pitch"][:2]
            energy_min, energy_max = stats["energy"][:2]

        if pitch_quantization == "log":
            self.pitch_bins = nn.Parameter(
                torch.exp(
                    torch.linspace(np.log(pitch_min), np.log(pitch_max), n_bins - 1)
                ),
                requires_grad=False,
            )
        else:
            self.pitch_bins = nn.Parameter(
                torch.linspace(pitch_min, pitch_max, n_bins - 1),
                requires_grad=False,
            )
        if energy_quantization == "log":
            self.energy_bins = nn.Parameter(
                torch.exp(
                    torch.linspace(np.log(energy_min), np.log(energy_max), n_bins - 1)
                ),
                requires_grad=False,
            )
        else:
            self.energy_bins = nn.Parameter(
                torch.linspace(energy_min, energy_max, n_bins - 1),
                requires_grad=False,
            )

        self.pitch_embedding = nn.Embedding(
            n_bins, model_config["transformer"]["encoder_hidden"]
        )
        self.energy_embedding = nn.Embedding(
            n_bins, model_config["transformer"]["encoder_hidden"]
        )
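
The constructor above creates n_bins - 1 boundaries but an embedding table with n_bins rows. The sketch below shows why those sizes fit together, assuming a pitch or energy value is mapped to an embedding index by bucketing it against the boundaries (the forward pass is not shown in this article, so this is an assumption). make_bins and bucket_index are hypothetical plain-Python helpers; bisect_left is used as a stand-in for a tensor bucketing call.

```python
import bisect
import math


def make_bins(v_min, v_max, n_bins, scale="linear"):
    """Return n_bins - 1 boundaries, evenly spaced on a linear or log scale."""
    n = n_bins - 1
    assert n >= 2, "need at least two boundaries"
    if scale == "log":
        lo, hi = math.log(v_min), math.log(v_max)
        return [math.exp(lo + (hi - lo) * i / (n - 1)) for i in range(n)]
    return [v_min + (v_max - v_min) * i / (n - 1) for i in range(n)]


def bucket_index(value, boundaries):
    """Index of the first boundary that is >= value (len(boundaries) if none)."""
    return bisect.bisect_left(boundaries, value)


bins = make_bins(0.0, 4.0, n_bins=6)
print(bins)                                                   # [0.0, 1.0, 2.0, 3.0, 4.0]
print([bucket_index(v, bins) for v in (-1.0, 0.5, 2.0, 9.0)])  # [0, 1, 2, 5]
```

With n_bins - 1 boundaries the index ranges from 0 to n_bins - 1, so values below the minimum and above the maximum each get their own bucket and every index is a valid row of the n_bins-row embedding.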

The Variance Predictor is also implemented in /model/modules.py:

class VariancePredictor(nn.Module):
    """Duration, Pitch and Energy Predictor"""

    def __init__(self, model_config):
        super(VariancePredictor, self).__init__()

        self.input_size = model_config["transformer"]["encoder_hidden"]
        self.filter_size = model_config["variance_predictor"]["filter_size"]
        self.kernel = model_config["variance_predictor"]["kernel_size"]
        self.conv_output_size = model_config["variance_predictor"]["filter_size"]
        self.dropout = model_config["variance_predictor"]["dropout"]

        self.conv_layer = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv1d_1",
                        Conv(
                            self.input_size,
                            self.filter_size,
                            kernel_size=self.kernel,
                            padding=(self.kernel - 1) // 2,
                        ),
                    ),
                    ("relu_1", nn.ReLU()),
                    ("layer_norm_1", nn.LayerNorm(self.filter_size)),
                    ("dropout_1", nn.Dropout(self.dropout)),
                    (
                        "conv1d_2",
                        Conv(
                            self.filter_size,
                            self.filter_size,
                            kernel_size=self.kernel,
                            padding=1,
                        ),
                    ),
                    ("relu_2", nn.ReLU()),
                    ("layer_norm_2", nn.LayerNorm(self.filter_size)),
                    ("dropout_2", nn.Dropout(self.dropout)),
                ]
            )
        )

        self.linear_layer = nn.Linear(self.conv_output_size, 1)
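
One detail in the code above: conv1d_1 uses padding=(self.kernel - 1) // 2, while conv1d_2 hard-codes padding=1. The two agree only when kernel_size is 3, which is the value in the printout. The sketch below applies the standard 1-D convolution output-length formula to show what would happen with a different kernel size; conv1d_out_len is a hypothetical helper.

```python
def conv1d_out_len(length, kernel_size, padding, stride=1, dilation=1):
    """Output length of a 1-D convolution for a given input length."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


for k in (3, 5):
    same_pad = (k - 1) // 2
    print(k, conv1d_out_len(100, k, same_pad), conv1d_out_len(100, k, 1))
# 3 100 100
# 5 100 98
```

With kernel_size=5, the second convolution would shorten the sequence by two steps, so the predictor's output would no longer line up one-to-one with its input positions.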

Feel free to share; when reposting, please credit the source: 内存溢出

Original article: http://outofmemory.cn/langs/569473.html
