Transformer Course: Understanding the Transformer Model for Language


Scaled dot product attention

The attention function used by the Transformer takes three inputs: Q (query), K (key), V (value). The equation used to calculate the attention weights is:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}_k\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

The dot-product attention is scaled down by the square root of the depth d_k. Assume that Q and K have a mean of 0 and a variance of 1; their matrix product then has a mean of 0 and a variance of d_k. The square root of d_k is used for the scaling because the matrix product of Q and K should ideally keep a mean of 0 and a variance of 1, which gives a gentler softmax.
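
A quick numerical sanity check of that claim (a standalone sketch; the sample size of 10000 and d_k = 512 are arbitrary choices):

import numpy as np

d_k = 512
q = np.random.normal(0.0, 1.0, size=(10000, d_k))
k = np.random.normal(0.0, 1.0, size=(10000, d_k))

dots = np.sum(q * k, axis=-1)           # 10000 independent dot products
print(np.var(dots))                     # close to d_k, i.e. about 512
print(np.var(dots / np.sqrt(d_k)))      # close to 1 after scaling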

The mask is multiplied by -1e9 (close to negative infinity) and added to the scaled matrix product of Q and K just before the softmax. The goal is to zero those cells out: large negative inputs to the softmax come out close to zero in its output.

Borrowing an example from the web: in the Transformer decoder, for a four-word sentence [A, B, C, D], after the similarity scores have been computed, the upper-triangular region of the score matrix is masked out, i.e. replaced with negative infinity, and the softmax is then applied. As a result, after self-attention the representation of B depends only on A and B, not on later tokens: in the softmax-weighted average B' = 0.48A + 0.52B, and C and D contribute nothing to B'.
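
A minimal sketch of that masking step for a 4-token sequence; the score values below are made up purely for illustration:

import tensorflow as tf

scores = tf.constant([[5.0, 2.0, 1.0, 0.0],
                      [3.0, 4.0, 1.0, 2.0],
                      [1.0, 2.0, 5.0, 1.0],
                      [2.0, 1.0, 3.0, 4.0]])   # (seq_len, seq_len) similarity scores

# 1 marks the upper-triangular positions to hide (future tokens)
look_ahead = 1 - tf.linalg.band_part(tf.ones((4, 4)), -1, 0)

masked_scores = scores + look_ahead * -1e9
weights = tf.nn.softmax(masked_scores, axis=-1)
# Row i now has non-zero weights only for positions <= i,
# e.g. the row for B mixes only A and B, and C, D contribute nothing.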

# Imports used by the code in this section (they may already be in scope
# if you ran the previous posts in this series).
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf


def scaled_dot_product_attention(q, k, v, mask):
  """Calculate the attention weights.
  q, k, v must have matching leading dimensions.
  k, v must have matching penultimate dimensions, i.e.: seq_len_k = seq_len_v.
  The mask has different shapes depending on its type (padding or look ahead),
  but it must be broadcastable for addition.

  Args:
    q: query shape == (..., seq_len_q, depth)
    k: key shape == (..., seq_len_k, depth)
    v: value shape == (..., seq_len_v, depth_v)
    mask: Float tensor with shape broadcastable
          to (..., seq_len_q, seq_len_k). Defaults to None.

  Returns:
    output, attention_weights
  """

  matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)
  
  # scale matmul_qk
  dk = tf.cast(tf.shape(k)[-1], tf.float32)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

  # add the mask to the scaled tensor.
  if mask is not None:
    scaled_attention_logits += (mask * -1e9)  

  # softmax is normalized on the last axis (seq_len_k) so that the scores
  # add up to 1.
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

  output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

  return output, attention_weights

As the softmax normalization is done over K, its values decide how much importance is given to each position for a given Q. The output is the product of the attention weights and the V (value) vectors. This ensures that the tokens you want to focus on are kept as-is while the irrelevant tokens are flushed out.

def print_out(q, k, v):
  temp_out, temp_attn = scaled_dot_product_attention(
      q, k, v, None)
  print ('Attention weights are:')
  print (temp_attn)
  print ('Output is:')
  print (temp_out)
np.set_printoptions(suppress=True)

temp_k = tf.constant([[10,0,0],
                      [0,10,0],
                      [0,0,10],
                      [0,0,10]], dtype=tf.float32)  # (4, 3)

temp_v = tf.constant([[   1,0],
                      [  10,0],
                      [ 100,5],
                      [1000,6]], dtype=tf.float32)  # (4, 2)

# This `query` matches the second `key`,
# so the second `value` is returned.
temp_q = tf.constant([[0, 10, 0]], dtype=tf.float32)  
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor([[0. 1. 0. 0.]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[10.  0.]], shape=(1, 2), dtype=float32)
# This query matches a repeated key (the third and fourth),
# so all associated values get averaged.
temp_q = tf.constant([[0, 0, 10]], dtype=tf.float32)   
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor([[0.  0.  0.5 0.5]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[550.    5.5]], shape=(1, 2), dtype=float32)
# This query matches the first and second key equally,
# so their values get averaged.
temp_q = tf.constant([[10, 10, 0]], dtype=tf.float32)   
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor([[0.5 0.5 0.  0. ]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[5.5 0. ]], shape=(1, 2), dtype=float32)

Pass all the queries together.

temp_q = tf.constant([[0, 0, 10], [0, 10, 0], [10, 10, 0]], dtype=tf.float32)  # (3, 3)
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor(
[[0.  0.  0.5 0.5]
 [0.  1.  0.  0. ]
 [0.5 0.5 0.  0. ]], shape=(3, 4), dtype=float32)
Output is:
tf.Tensor(
[[550.    5.5]
 [ 10.    0. ]
 [  5.5   0. ]], shape=(3, 2), dtype=float32)
Multi-head attention

Multi-head attention consists of four parts:

  • Linear layers, split into heads.
  • Scaled dot-product attention.
  • Concatenation of the heads.
  • A final linear layer.

Each multi-head attention block gets three inputs: Q (query), K (key), V (value). These are put through linear (Dense) layers and split up into multiple heads.

The scaled_dot_product_attention defined above is applied to each head (broadcast for efficiency). An appropriate mask must be used in this attention step. The attention output of each head is then concatenated (using tf.transpose and tf.reshape) and put through a final Dense layer.

Q, K, and V are split into multiple heads instead of a single attention head because this allows the model to jointly attend to information at different positions from different representational spaces. After the split, each head works with a reduced dimensionality, so the total computation cost is the same as a single head with full dimensionality.

class MultiHeadAttention(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads):
    super(MultiHeadAttention, self).__init__()
    self.num_heads = num_heads
    self.d_model = d_model
    
    assert d_model % self.num_heads == 0
    
    self.depth = d_model // self.num_heads
    
    self.wq = tf.keras.layers.Dense(d_model)
    self.wk = tf.keras.layers.Dense(d_model)
    self.wv = tf.keras.layers.Dense(d_model)
    
    self.dense = tf.keras.layers.Dense(d_model)
        
  def split_heads(self, x, batch_size):
    """分拆最后一个维度到 (num_heads, depth).
    转置结果使得形状为 (batch_size, num_heads, seq_len, depth)
    """
    x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
    return tf.transpose(x, perm=[0, 2, 1, 3])
    
  def call(self, v, k, q, mask):
    batch_size = tf.shape(q)[0]
    
    q = self.wq(q)  # (batch_size, seq_len, d_model)
    k = self.wk(k)  # (batch_size, seq_len, d_model)
    v = self.wv(v)  # (batch_size, seq_len, d_model)
    
    q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
    k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
    v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)
    
    # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
    # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
    scaled_attention, attention_weights = scaled_dot_product_attention(
        q, k, v, mask)
    
    scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)

    concat_attention = tf.reshape(scaled_attention, 
                                  (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

    output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
        
    return output, attention_weights

Create a MultiHeadAttention layer to try out. At each location y in the sequence, MultiHeadAttention runs all 8 attention heads across all other locations in the sequence and returns a new vector of the same length at each location.

temp_mha = MultiHeadAttention(d_model=512, num_heads=8)
y = tf.random.uniform((1, 60, 512))  # (batch_size, encoder_sequence, d_model)
out, attn = temp_mha(y, k=y, q=y, mask=None)
out.shape, attn.shape
(TensorShape([1, 60, 512]), TensorShape([1, 8, 60, 60]))
Point wise feed forward network

The point wise feed forward network consists of two fully-connected layers with a ReLU activation in between.

def point_wise_feed_forward_network(d_model, dff):
  return tf.keras.Sequential([
      tf.keras.layers.Dense(dff, activation='relu'),  # (batch_size, seq_len, dff)
      tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
  ])
sample_ffn = point_wise_feed_forward_network(512, 2048)
sample_ffn(tf.random.uniform((64, 50, 512))).shape
TensorShape([64, 50, 512])
Encoder and decoder

The Transformer follows the same general pattern as a standard sequence-to-sequence with attention model.

  • The input sentence is passed through N encoder layers that generate an output for each word/token in the sequence.
  • The decoder attends to the encoder's output and its own input (self-attention) to predict the next word.
Encoder layer

Each encoder layer consists of the following sublayers:

  1. Multi-head attention (with padding mask)
  2. Point wise feed forward networks.

Each of these sublayers has a residual connection around it, followed by a layer normalization. Residual connections help avoid the vanishing gradient problem in deep networks.

The output of each sublayer is LayerNorm(x + Sublayer(x)). The normalization is done on the d_model (last) axis. There are N encoder layers in the Transformer.
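
A small standalone check of that last point (the shapes below are arbitrary): tf.keras.layers.LayerNormalization normalizes over its last axis by default, which here is the d_model axis.

layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
x = tf.random.uniform((2, 5, 512))            # (batch, seq_len, d_model)
y = layernorm(x)
print(y.shape)                                # (2, 5, 512)
print(tf.reduce_mean(y, axis=-1)[0, 0])       # approximately 0 at every position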

class EncoderLayer(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads, dff, rate=0.1):
    super(EncoderLayer, self).__init__()

    self.mha = MultiHeadAttention(d_model, num_heads)
    self.ffn = point_wise_feed_forward_network(d_model, dff)

    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    
    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)
    
  def call(self, x, training, mask):

    attn_output, _ = self.mha(x, x, x, mask)  # (batch_size, input_seq_len, d_model)
    attn_output = self.dropout1(attn_output, training=training)
    out1 = self.layernorm1(x + attn_output)  # (batch_size, input_seq_len, d_model)
    
    ffn_output = self.ffn(out1)  # (batch_size, input_seq_len, d_model)
    ffn_output = self.dropout2(ffn_output, training=training)
    out2 = self.layernorm2(out1 + ffn_output)  # (batch_size, input_seq_len, d_model)
    
    return out2
sample_encoder_layer = EncoderLayer(512, 8, 2048)

sample_encoder_layer_output = sample_encoder_layer(
    tf.random.uniform((64, 43, 512)), False, None)

sample_encoder_layer_output.shape  # (batch_size, input_seq_len, d_model)
TensorShape([64, 43, 512])
Decoder layer

Each decoder layer consists of the following sublayers:

  1. Masked multi-head attention (with look-ahead mask and padding mask)
  2. Multi-head attention (with padding mask). V (value) and K (key) receive the encoder output as inputs. Q (query) receives the output from the masked multi-head attention sublayer.
  3. Point wise feed forward networks

Each of these sublayers has a residual connection around it, followed by a layer normalization. The output of each sublayer is LayerNorm(x + Sublayer(x)). The normalization is done on the d_model (last) axis.

There are N decoder layers in the Transformer.

As Q receives the output from the decoder's first attention block and K receives the encoder output, the attention weights represent the importance given to the decoder's input based on the encoder's output. In other words, the decoder predicts the next word by looking at the encoder output and self-attending to its own output. See the demonstration in the scaled dot-product attention section above.

class DecoderLayer(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads, dff, rate=0.1):
    super(DecoderLayer, self).__init__()

    self.mha1 = MultiHeadAttention(d_model, num_heads)
    self.mha2 = MultiHeadAttention(d_model, num_heads)

    self.ffn = point_wise_feed_forward_network(d_model, dff)
 
    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    
    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)
    self.dropout3 = tf.keras.layers.Dropout(rate)
    
    
  def call(self, x, enc_output, training, 
           look_ahead_mask, padding_mask):
    # enc_output.shape == (batch_size, input_seq_len, d_model)

    attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)  # (batch_size, target_seq_len, d_model)
    attn1 = self.dropout1(attn1, training=training)
    out1 = self.layernorm1(attn1 + x)
    
    attn2, attn_weights_block2 = self.mha2(
        enc_output, enc_output, out1, padding_mask)  # (batch_size, target_seq_len, d_model)
    attn2 = self.dropout2(attn2, training=training)
    out2 = self.layernorm2(attn2 + out1)  # (batch_size, target_seq_len, d_model)
    
    ffn_output = self.ffn(out2)  # (batch_size, target_seq_len, d_model)
    ffn_output = self.dropout3(ffn_output, training=training)
    out3 = self.layernorm3(ffn_output + out2)  # (batch_size, target_seq_len, d_model)
    
    return out3, attn_weights_block1, attn_weights_block2
sample_decoder_layer = DecoderLayer(512, 8, 2048)

sample_decoder_layer_output, _, _ = sample_decoder_layer(
    tf.random.uniform((64, 50, 512)), sample_encoder_layer_output, 
    False, None, None)

sample_decoder_layer_output.shape  # (batch_size, target_seq_len, d_model)
TensorShape([64, 50, 512])
Encoder

The Encoder consists of:

  1. Input Embedding
  2. Positional Encoding
  3. N encoder layers

The input is put through an embedding, which is summed with the positional encoding. The output of this summation is the input to the encoder layers. The output of the encoder is the input to the decoder.
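
The positional_encoding function used below comes from the previous post in this series (Positional Encoding and Masking); a sketch consistent with the TensorFlow tutorial's definition is repeated here so this section can be read on its own:

def get_angles(pos, i, d_model):
  angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
  return pos * angle_rates

def positional_encoding(position, d_model):
  angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                          np.arange(d_model)[np.newaxis, :],
                          d_model)

  # apply sin to even indices in the array; 2i
  angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])

  # apply cos to odd indices in the array; 2i+1
  angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])

  pos_encoding = angle_rads[np.newaxis, ...]

  return tf.cast(pos_encoding, dtype=tf.float32)  # (1, position, d_model)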

class Encoder(tf.keras.layers.Layer):
  def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
               maximum_position_encoding, rate=0.1):
    super(Encoder, self).__init__()

    self.d_model = d_model
    self.num_layers = num_layers
    
    self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
    self.pos_encoding = positional_encoding(maximum_position_encoding, 
                                            self.d_model)
    
    
    self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate) 
                       for _ in range(num_layers)]
  
    self.dropout = tf.keras.layers.Dropout(rate)
        
  def call(self, x, training, mask):

    seq_len = tf.shape(x)[1]
    
    # adding embedding and position encoding.
    x = self.embedding(x)  # (batch_size, input_seq_len, d_model)
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x += self.pos_encoding[:, :seq_len, :]

    x = self.dropout(x, training=training)
    
    for i in range(self.num_layers):
      x = self.enc_layers[i](x, training, mask)
    
    return x  # (batch_size, input_seq_len, d_model)
sample_encoder = Encoder(num_layers=2, d_model=512, num_heads=8, 
                         dff=2048, input_vocab_size=8500,
                         maximum_position_encoding=10000)

sample_encoder_output = sample_encoder(tf.random.uniform((64, 62)), 
                                       training=False, mask=None)

print (sample_encoder_output.shape)  # (batch_size, input_seq_len, d_model)
(64, 62, 512)
Decoder

The Decoder consists of:

  1. Output Embedding
  2. Positional Encoding
  3. N decoder layers

The target is put through an embedding, which is summed with the positional encoding. The output of this summation is the input to the decoder layers. The output of the decoder is the input to the final linear layer.

class Decoder(tf.keras.layers.Layer):
  def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size,
               maximum_position_encoding, rate=0.1):
    super(Decoder, self).__init__()

    self.d_model = d_model
    self.num_layers = num_layers
    
    self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
    self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)
    
    self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate) 
                       for _ in range(num_layers)]
    self.dropout = tf.keras.layers.Dropout(rate)
    
  def call(self, x, enc_output, training, 
           look_ahead_mask, padding_mask):

    seq_len = tf.shape(x)[1]
    attention_weights = {}
    
    x = self.embedding(x)  # (batch_size, target_seq_len, d_model)
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x += self.pos_encoding[:, :seq_len, :]
    
    x = self.dropout(x, training=training)

    for i in range(self.num_layers):
      x, block1, block2 = self.dec_layers[i](x, enc_output, training,
                                             look_ahead_mask, padding_mask)
      
      attention_weights['decoder_layer{}_block1'.format(i+1)] = block1
      attention_weights['decoder_layer{}_block2'.format(i+1)] = block2
    
    # x.shape == (batch_size, target_seq_len, d_model)
    return x, attention_weights
sample_decoder = Decoder(num_layers=2, d_model=512, num_heads=8, 
                         dff=2048, target_vocab_size=8000,
                         maximum_position_encoding=5000)

output, attn = sample_decoder(tf.random.uniform((64, 26)), 
                              enc_output=sample_encoder_output, 
                              training=False, look_ahead_mask=None, 
                              padding_mask=None)

output.shape, attn['decoder_layer2_block2'].shape
(TensorShape([64, 26, 512]), TensorShape([64, 8, 26, 62]))
Create the Transformer

The Transformer consists of the encoder, the decoder, and a final linear layer. The output of the decoder is the input to the linear layer, and the linear layer's output is returned.

class Transformer(tf.keras.Model):
  def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, 
               target_vocab_size, pe_input, pe_target, rate=0.1):
    super(Transformer, self).__init__()

    self.encoder = Encoder(num_layers, d_model, num_heads, dff, 
                           input_vocab_size, pe_input, rate)

    self.decoder = Decoder(num_layers, d_model, num_heads, dff, 
                           target_vocab_size, pe_target, rate)

    self.final_layer = tf.keras.layers.Dense(target_vocab_size)
    
  def call(self, inp, tar, training, enc_padding_mask, 
           look_ahead_mask, dec_padding_mask):

    enc_output = self.encoder(inp, training, enc_padding_mask)  # (batch_size, inp_seq_len, d_model)
    
    # dec_output.shape == (batch_size, tar_seq_len, d_model)
    dec_output, attention_weights = self.decoder(
        tar, enc_output, training, look_ahead_mask, dec_padding_mask)
    
    final_output = self.final_layer(dec_output)  # (batch_size, tar_seq_len, target_vocab_size)
    
    return final_output, attention_weights
sample_transformer = Transformer(
    num_layers=2, d_model=512, num_heads=8, dff=2048, 
    input_vocab_size=8500, target_vocab_size=8000, 
    pe_input=10000, pe_target=6000)

temp_input = tf.random.uniform((64, 62))
temp_target = tf.random.uniform((64, 26))

fn_out, _ = sample_transformer(temp_input, temp_target, training=False, 
                               enc_padding_mask=None, 
                               look_ahead_mask=None,
                               dec_padding_mask=None)

fn_out.shape  # (batch_size, tar_seq_len, target_vocab_size)
TensorShape([64, 26, 8000])
Set hyperparameters

To keep this example small and relatively fast, the values of num_layers, d_model, and dff have been reduced.

The base Transformer model uses num_layers=6, d_model=512, and dff=2048. See the paper for all the other versions of the Transformer.

Note: by changing the values below, you can get a model that achieves state-of-the-art results on many tasks.

num_layers = 4
d_model = 128
dff = 512
num_heads = 8

input_vocab_size = tokenizer_pt.vocab_size + 2
target_vocab_size = tokenizer_en.vocab_size + 2
dropout_rate = 0.1
Optimizer

Use the Adam optimizer with a custom learning rate scheduler, according to the formula in the paper.

$$lrate = d_{model}^{-0.5} \cdot \min\left(step\_num^{-0.5},\ step\_num \cdot warmup\_steps^{-1.5}\right)$$

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  def __init__(self, d_model, warmup_steps=4000):
    super(CustomSchedule, self).__init__()
    
    self.d_model = d_model
    self.d_model = tf.cast(self.d_model, tf.float32)

    self.warmup_steps = warmup_steps
    
  def __call__(self, step):
    arg1 = tf.math.rsqrt(step)
    arg2 = step * (self.warmup_steps ** -1.5)
    
    return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
learning_rate = CustomSchedule(d_model)

optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, 
                                     epsilon=1e-9)
temp_learning_rate_schedule = CustomSchedule(d_model)

plt.plot(temp_learning_rate_schedule(tf.range(40000, dtype=tf.float32)))
plt.ylabel("Learning Rate")
plt.xlabel("Train Step")

Text(0.5, 0, 'Train Step')
Loss and metrics

Since the target sequences are padded, it is important to apply a padding mask when calculating the loss.

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')
def loss_function(real, pred):
  mask = tf.math.logical_not(tf.math.equal(real, 0))
  loss_ = loss_object(real, pred)

  mask = tf.cast(mask, dtype=loss_.dtype)
  loss_ *= mask
  
  return tf.reduce_mean(loss_)
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='train_accuracy')
Training and checkpointing
transformer = Transformer(num_layers, d_model, num_heads, dff,
                          input_vocab_size, target_vocab_size, 
                          pe_input=input_vocab_size, 
                          pe_target=target_vocab_size,
                          rate=dropout_rate)
def create_masks(inp, tar):
  # Encoder padding mask
  enc_padding_mask = create_padding_mask(inp)
  
  # Used in the 2nd attention block in the decoder.
  # This padding mask is used to mask the encoder outputs.
  dec_padding_mask = create_padding_mask(inp)
  
  # Used in the 1st attention block in the decoder.
  # It is used to pad and mask future tokens in the input received by the decoder.
  look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
  dec_target_padding_mask = create_padding_mask(tar)
  combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)
  
  return enc_padding_mask, combined_mask, dec_padding_mask
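
create_padding_mask and create_look_ahead_mask also come from the previous post on masking; a sketch consistent with how create_masks uses them:

def create_padding_mask(seq):
  # 1.0 where the token id is 0 (padding), 0.0 elsewhere
  seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
  # add extra dimensions so the mask broadcasts to the attention logits
  return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)

def create_look_ahead_mask(size):
  # strictly upper-triangular ones: position i may not attend to positions > i
  mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
  return mask  # (seq_len, seq_len)

# e.g. create_look_ahead_mask(3) ->
# [[0., 1., 1.],
#  [0., 0., 1.],
#  [0., 0., 0.]]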

Create the checkpoint path and the checkpoint manager. This will be used to save checkpoints every n epochs.

checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(transformer=transformer,
                           optimizer=optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

The target is divided into tar_inp and tar_real. tar_inp is passed as an input to the decoder. tar_real is the same input shifted by 1: at each position in tar_inp, tar_real contains the next token that should be predicted.

For example, sentence = “SOS A lion in the jungle is sleeping EOS”

tar_inp = “SOS A lion in the jungle is sleeping”

tar_real = “A lion in the jungle is sleeping EOS”
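
A toy illustration of that split with made-up token ids (the real split happens inside train_step further below):

# hypothetical ids for "SOS A lion ... EOS" as one sentence in a batch
tar = tf.constant([[8000, 5, 87, 23, 41, 9, 3, 12, 8001]])

tar_inp = tar[:, :-1]   # everything except the last token goes into the decoder
tar_real = tar[:, 1:]   # everything except the first token is what it must predict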

The Transformer is an auto-regressive model: it makes predictions one part at a time and uses its output so far to decide what to do next.

During training, this example uses teacher forcing (as in the text generation tutorial). Teacher forcing passes the true output to the next time step regardless of what the model predicts at the current time step.

As the Transformer predicts each word, self-attention allows it to look at the previous words in the input sequence to better predict the next word.

To prevent the model from peeking at the expected output, the model uses a look-ahead mask.

EPOCHS = 20
# The @tf.function trace-compiles train_step into a TF graph for faster
# execution. The function specializes to the precise shape of the argument
# tensors. To avoid re-tracing due to variable sequence lengths or variable
# batch sizes (the last batch is smaller), use input_signature to specify
# more generic shapes.

train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]

@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
  tar_inp = tar[:, :-1]
  tar_real = tar[:, 1:]
  
  enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)
  
  with tf.GradientTape() as tape:
    predictions, _ = transformer(inp, tar_inp, 
                                 True, 
                                 enc_padding_mask, 
                                 combined_mask, 
                                 dec_padding_mask)
    loss = loss_function(tar_real, predictions)

  gradients = tape.gradient(loss, transformer.trainable_variables)    
  optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))
  
  train_loss(loss)
  train_accuracy(tar_real, predictions)

Portuguese is used as the input language and English is the target language.

for epoch in range(EPOCHS):
  start = time.time()
  
  train_loss.reset_states()
  train_accuracy.reset_states()
  
  # inp -> portuguese, tar -> english
  for (batch, (inp, tar)) in enumerate(train_dataset):
    train_step(inp, tar)
    
    if batch % 50 == 0:
      print ('Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}'.format(
          epoch + 1, batch, train_loss.result(), train_accuracy.result()))
      
  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))
    
  print ('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1, 
                                                train_loss.result(), 
                                                train_accuracy.result()))

  print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))
Epoch 1 Batch 0 Loss 4.7219 Accuracy 0.0000
Epoch 1 Batch 50 Loss 4.2328 Accuracy 0.0016
Epoch 1 Batch 100 Loss 4.1794 Accuracy 0.0129
Epoch 1 Batch 150 Loss 4.1244 Accuracy 0.0174
Epoch 1 Batch 200 Loss 4.0585 Accuracy 0.0199
Epoch 1 Batch 250 Loss 3.9918 Accuracy 0.0216
Epoch 1 Batch 300 Loss 3.9194 Accuracy 0.0240
Epoch 1 Batch 350 Loss 3.8279 Accuracy 0.0277
Epoch 1 Batch 400 Loss 3.7333 Accuracy 0.0308
Epoch 1 Batch 450 Loss 3.6568 Accuracy 0.0342
Epoch 1 Batch 500 Loss 3.5911 Accuracy 0.0373
Epoch 1 Batch 550 Loss 3.5256 Accuracy 0.0402
Epoch 1 Batch 600 Loss 3.4705 Accuracy 0.0434
Epoch 1 Batch 650 Loss 3.4186 Accuracy 0.0468
Epoch 1 Batch 700 Loss 3.3658 Accuracy 0.0503
Epoch 1 Loss 3.3638 Accuracy 0.0504
Time taken for 1 epoch: 860.6076409816742 secs

Epoch 2 Batch 0 Loss 2.5517 Accuracy 0.0987
Epoch 2 Batch 50 Loss 2.6231 Accuracy 0.1023
Epoch 2 Batch 100 Loss 2.5504 Accuracy 0.1040
Epoch 2 Batch 150 Loss 2.5346 Accuracy 0.1066
Epoch 2 Batch 200 Loss 2.5117 Accuracy 0.1089
Epoch 2 Batch 250 Loss 2.4920 Accuracy 0.1107
Epoch 2 Batch 300 Loss 2.4739 Accuracy 0.1128
Epoch 2 Batch 350 Loss 2.4587 Accuracy 0.1146
Epoch 2 Batch 400 Loss 2.4463 Accuracy 0.1163
Epoch 2 Batch 450 Loss 2.4335 Accuracy 0.1181
Epoch 2 Batch 500 Loss 2.4200 Accuracy 0.1196
Epoch 2 Batch 550 Loss 2.4082 Accuracy 0.1210
Epoch 2 Batch 600 Loss 2.3951 Accuracy 0.1223
Epoch 2 Batch 650 Loss 2.3828 Accuracy 0.1235
Epoch 2 Batch 700 Loss 2.3702 Accuracy 0.1247
Epoch 2 Loss 2.3699 Accuracy 0.1248
Time taken for 1 epoch: 786.3172492980957 secs

Epoch 3 Batch 0 Loss 2.1001 Accuracy 0.1340
Epoch 3 Batch 50 Loss 2.1539 Accuracy 0.1425
Epoch 3 Batch 100 Loss 2.1499 Accuracy 0.1427
Epoch 3 Batch 150 Loss 2.1640 Accuracy 0.1438
Epoch 3 Batch 200 Loss 2.1518 Accuracy 0.1442
Epoch 3 Batch 250 Loss 2.1446 Accuracy 0.1453
Epoch 3 Batch 300 Loss 2.1378 Accuracy 0.1457
Epoch 3 Batch 350 Loss 2.1395 Accuracy 0.1464
Epoch 3 Batch 400 Loss 2.1364 Accuracy 0.1470
Epoch 3 Batch 450 Loss 2.1290 Accuracy 0.1476
Epoch 3 Batch 500 Loss 2.1234 Accuracy 0.1481
Epoch 3 Batch 550 Loss 2.1164 Accuracy 0.1487
Epoch 3 Batch 600 Loss 2.1150 Accuracy 0.1495
Epoch 3 Batch 650 Loss 2.1096 Accuracy 0.1503
Epoch 3 Batch 700 Loss 2.1025 Accuracy 0.1512
Epoch 3 Loss 2.1027 Accuracy 0.1512
Time taken for 1 epoch: 794.0543868541718 secs

Epoch 4 Batch 0 Loss 1.8834 Accuracy 0.1562
Epoch 4 Batch 50 Loss 1.9649 Accuracy 0.1642
Epoch 4 Batch 100 Loss 1.9310 Accuracy 0.1645
Epoch 4 Batch 150 Loss 1.9310 Accuracy 0.1655
Epoch 4 Batch 200 Loss 1.9283 Accuracy 0.1663
Epoch 4 Batch 250 Loss 1.9315 Accuracy 0.1671
Epoch 4 Batch 300 Loss 1.9275 Accuracy 0.1680
Epoch 4 Batch 350 Loss 1.9212 Accuracy 0.1691
Epoch 4 Batch 400 Loss 1.9191 Accuracy 0.1703
Epoch 4 Batch 450 Loss 1.9122 Accuracy 0.1713
Epoch 4 Batch 500 Loss 1.9015 Accuracy 0.1720
Epoch 4 Batch 550 Loss 1.8952 Accuracy 0.1728
Epoch 4 Batch 600 Loss 1.8921 Accuracy 0.1740
Epoch 4 Batch 650 Loss 1.8878 Accuracy 0.1751
Epoch 4 Batch 700 Loss 1.8795 Accuracy 0.1759
Epoch 4 Loss 1.8792 Accuracy 0.1759
Time taken for 1 epoch: 823.7429077625275 secs

Epoch 5 Batch 0 Loss 1.7172 Accuracy 0.1832
Epoch 5 Batch 50 Loss 1.7015 Accuracy 0.1932
Epoch 5 Batch 100 Loss 1.7292 Accuracy 0.1948
Epoch 5 Batch 150 Loss 1.7166 Accuracy 0.1955
Epoch 5 Batch 200 Loss 1.7153 Accuracy 0.1957
Epoch 5 Batch 250 Loss 1.7020 Accuracy 0.1964
Epoch 5 Batch 300 Loss 1.6987 Accuracy 0.1970
Epoch 5 Batch 350 Loss 1.6923 Accuracy 0.1975
Epoch 5 Batch 400 Loss 1.6881 Accuracy 0.1983
Epoch 5 Batch 450 Loss 1.6829 Accuracy 0.1993
Epoch 5 Batch 500 Loss 1.6784 Accuracy 0.1999
Epoch 5 Batch 550 Loss 1.6744 Accuracy 0.2006
Epoch 5 Batch 600 Loss 1.6699 Accuracy 0.2011
Epoch 5 Batch 650 Loss 1.6664 Accuracy 0.2019
Epoch 5 Batch 700 Loss 1.6630 Accuracy 0.2027
Saving checkpoint for epoch 5 at ./checkpoints/train\ckpt-1
Epoch 5 Loss 1.6633 Accuracy 0.2028
Time taken for 1 epoch: 827.5387122631073 secs

Epoch 6 Batch 0 Loss 1.4437 Accuracy 0.1983
Epoch 6 Batch 50 Loss 1.4899 Accuracy 0.2173
Epoch 6 Batch 100 Loss 1.4944 Accuracy 0.2183
Epoch 6 Batch 150 Loss 1.5143 Accuracy 0.2209
Epoch 6 Batch 200 Loss 1.5165 Accuracy 0.2210
Epoch 6 Batch 250 Loss 1.5101 Accuracy 0.2210
Epoch 6 Batch 300 Loss 1.5062 Accuracy 0.2206
Epoch 6 Batch 350 Loss 1.5027 Accuracy 0.2213
Epoch 6 Batch 400 Loss 1.4987 Accuracy 0.2217
Epoch 6 Batch 450 Loss 1.4963 Accuracy 0.2222
Epoch 6 Batch 500 Loss 1.4915 Accuracy 0.2226
Epoch 6 Batch 550 Loss 1.4905 Accuracy 0.2232
Epoch 6 Batch 600 Loss 1.4900 Accuracy 0.2234
Epoch 6 Batch 650 Loss 1.4854 Accuracy 0.2237
Epoch 6 Batch 700 Loss 1.4820 Accuracy 0.2240
Epoch 6 Loss 1.4821 Accuracy 0.2241
Time taken for 1 epoch: 830.6795747280121 secs

Epoch 7 Batch 0 Loss 1.4233 Accuracy 0.2436
Epoch 7 Batch 50 Loss 1.3115 Accuracy 0.2377
Epoch 7 Batch 100 Loss 1.3084 Accuracy 0.2360
Epoch 7 Batch 150 Loss 1.3231 Accuracy 0.2376
Epoch 7 Batch 200 Loss 1.3231 Accuracy 0.2386
Epoch 7 Batch 250 Loss 1.3235 Accuracy 0.2396
Epoch 7 Batch 300 Loss 1.3212 Accuracy 0.2407
Epoch 7 Batch 350 Loss 1.3174 Accuracy 0.2413
Epoch 7 Batch 400 Loss 1.3142 Accuracy 0.2421
Epoch 7 Batch 450 Loss 1.3117 Accuracy 0.2425
Epoch 7 Batch 500 Loss 1.3088 Accuracy 0.2429
Epoch 7 Batch 550 Loss 1.3039 Accuracy 0.2430
Epoch 7 Batch 600 Loss 1.3014 Accuracy 0.2433
Epoch 7 Batch 650 Loss 1.2985 Accuracy 0.2436
Epoch 7 Batch 700 Loss 1.2960 Accuracy 0.2441
Epoch 7 Loss 1.2959 Accuracy 0.2441
Time taken for 1 epoch: 821.1422040462494 secs

Epoch 8 Batch 0 Loss 1.1575 Accuracy 0.2336
Epoch 8 Batch 50 Loss 1.1336 Accuracy 0.2604
Epoch 8 Batch 100 Loss 1.1431 Accuracy 0.2605
Epoch 8 Batch 150 Loss 1.1392 Accuracy 0.2597
Epoch 8 Batch 200 Loss 1.1435 Accuracy 0.2601
Epoch 8 Batch 250 Loss 1.1474 Accuracy 0.2603
Epoch 8 Batch 300 Loss 1.1514 Accuracy 0.2612
Epoch 8 Batch 350 Loss 1.1500 Accuracy 0.2617
Epoch 8 Batch 400 Loss 1.1478 Accuracy 0.2615
Epoch 8 Batch 450 Loss 1.1434 Accuracy 0.2618
Epoch 8 Batch 500 Loss 1.1409 Accuracy 0.2621
Epoch 8 Batch 550 Loss 1.1402 Accuracy 0.2624
Epoch 8 Batch 600 Loss 1.1420 Accuracy 0.2627
Epoch 8 Batch 650 Loss 1.1414 Accuracy 0.2628
Epoch 8 Batch 700 Loss 1.1419 Accuracy 0.2630
Epoch 8 Loss 1.1417 Accuracy 0.2631
Time taken for 1 epoch: 825.8510296344757 secs

Epoch 9 Batch 0 Loss 0.9726 Accuracy 0.2656
Epoch 9 Batch 50 Loss 1.0043 Accuracy 0.2764
Epoch 9 Batch 100 Loss 1.0109 Accuracy 0.2786
Epoch 9 Batch 150 Loss 1.0185 Accuracy 0.2789
Epoch 9 Batch 200 Loss 1.0228 Accuracy 0.2796
Epoch 9 Batch 250 Loss 1.0270 Accuracy 0.2793
Epoch 9 Batch 300 Loss 1.0282 Accuracy 0.2792
Epoch 9 Batch 350 Loss 1.0275 Accuracy 0.2788
Epoch 9 Batch 400 Loss 1.0277 Accuracy 0.2792
Epoch 9 Batch 450 Loss 1.0261 Accuracy 0.2788
Epoch 9 Batch 500 Loss 1.0275 Accuracy 0.2783
Epoch 9 Batch 550 Loss 1.0286 Accuracy 0.2782
Epoch 9 Batch 600 Loss 1.0302 Accuracy 0.2781
Epoch 9 Batch 650 Loss 1.0297 Accuracy 0.2779
Epoch 9 Batch 700 Loss 1.0284 Accuracy 0.2778
Epoch 9 Loss 1.0286 Accuracy 0.2778
Time taken for 1 epoch: 843.3068943023682 secs

Epoch 10 Batch 0 Loss 0.8247 Accuracy 0.2948
Epoch 10 Batch 50 Loss 0.9233 Accuracy 0.2853
Epoch 10 Batch 100 Loss 0.9149 Accuracy 0.2867
Epoch 10 Batch 150 Loss 0.9106 Accuracy 0.2860
Epoch 10 Batch 200 Loss 0.9154 Accuracy 0.2859
Epoch 10 Batch 250 Loss 0.9152 Accuracy 0.2860
Epoch 10 Batch 300 Loss 0.9197 Accuracy 0.2862
Epoch 10 Batch 350 Loss 0.9247 Accuracy 0.2869
Epoch 10 Batch 400 Loss 0.9243 Accuracy 0.2874
Epoch 10 Batch 450 Loss 0.9283 Accuracy 0.2879
Epoch 10 Batch 500 Loss 0.9309 Accuracy 0.2879
Epoch 10 Batch 550 Loss 0.9352 Accuracy 0.2879
Epoch 10 Batch 600 Loss 0.9369 Accuracy 0.2880
Epoch 10 Batch 650 Loss 0.9395 Accuracy 0.2880
Epoch 10 Batch 700 Loss 0.9413 Accuracy 0.2881
Saving checkpoint for epoch 10 at ./checkpoints/train\ckpt-2
Epoch 10 Loss 0.9414 Accuracy 0.2881
Time taken for 1 epoch: 829.3768057823181 secs

Epoch 11 Batch 0 Loss 0.8191 Accuracy 0.2973
Epoch 11 Batch 50 Loss 0.8537 Accuracy 0.3009
Epoch 11 Batch 100 Loss 0.8455 Accuracy 0.3007
Epoch 11 Batch 150 Loss 0.8511 Accuracy 0.2998
Epoch 11 Batch 200 Loss 0.8529 Accuracy 0.3000
Epoch 11 Batch 250 Loss 0.8545 Accuracy 0.2995
Epoch 11 Batch 300 Loss 0.8589 Accuracy 0.2995
Epoch 11 Batch 350 Loss 0.8609 Accuracy 0.2990
Epoch 11 Batch 400 Loss 0.8639 Accuracy 0.2992
Epoch 11 Batch 450 Loss 0.8642 Accuracy 0.2994
Epoch 11 Batch 500 Loss 0.8637 Accuracy 0.2993
Epoch 11 Batch 550 Loss 0.8674 Accuracy 0.2992
Epoch 11 Batch 600 Loss 0.8697 Accuracy 0.2987
Epoch 11 Batch 650 Loss 0.8713 Accuracy 0.2985
Epoch 11 Batch 700 Loss 0.8739 Accuracy 0.2982
Epoch 11 Loss 0.8740 Accuracy 0.2982
Time taken for 1 epoch: 816.3042690753937 secs

Epoch 12 Batch 0 Loss 0.6623 Accuracy 0.2870
Epoch 12 Batch 50 Loss 0.7684 Accuracy 0.3077
Epoch 12 Batch 100 Loss 0.7826 Accuracy 0.3056
Epoch 12 Batch 150 Loss 0.7898 Accuracy 0.3066
Epoch 12 Batch 200 Loss 0.7945 Accuracy 0.3073
Epoch 12 Batch 250 Loss 0.7960 Accuracy 0.3064
Epoch 12 Batch 300 Loss 0.7980 Accuracy 0.3067
Epoch 12 Batch 350 Loss 0.8031 Accuracy 0.3070
Epoch 12 Batch 400 Loss 0.8074 Accuracy 0.3072
Epoch 12 Batch 450 Loss 0.8093 Accuracy 0.3069
Epoch 12 Batch 500 Loss 0.8094 Accuracy 0.3061
Epoch 12 Batch 550 Loss 0.8093 Accuracy 0.3057
Epoch 12 Batch 600 Loss 0.8132 Accuracy 0.3056
Epoch 12 Batch 650 Loss 0.8144 Accuracy 0.3056
Epoch 12 Batch 700 Loss 0.8158 Accuracy 0.3050
Epoch 12 Loss 0.8156 Accuracy 0.3049
Time taken for 1 epoch: 793.5989410877228 secs

Epoch 13 Batch 0 Loss 0.8499 Accuracy 0.3223
Epoch 13 Batch 50 Loss 0.7339 Accuracy 0.3176
Epoch 13 Batch 100 Loss 0.7278 Accuracy 0.3164
Epoch 13 Batch 150 Loss 0.7380 Accuracy 0.3165
Epoch 13 Batch 200 Loss 0.7473 Accuracy 0.3160
Epoch 13 Batch 250 Loss 0.7495 Accuracy 0.3156
Epoch 13 Batch 300 Loss 0.7509 Accuracy 0.3145
Epoch 13 Batch 350 Loss 0.7516 Accuracy 0.3137
Epoch 13 Batch 400 Loss 0.7539 Accuracy 0.3133
Epoch 13 Batch 450 Loss 0.7570 Accuracy 0.3129
Epoch 13 Batch 500 Loss 0.7601 Accuracy 0.3129
Epoch 13 Batch 550 Loss 0.7633 Accuracy 0.3128
Epoch 13 Batch 600 Loss 0.7660 Accuracy 0.3126
Epoch 13 Batch 650 Loss 0.7684 Accuracy 0.3122
Epoch 13 Batch 700 Loss 0.7695 Accuracy 0.3120
Epoch 13 Loss 0.7695 Accuracy 0.3120
Time taken for 1 epoch: 796.3926935195923 secs

Epoch 14 Batch 0 Loss 0.6154 Accuracy 0.3142
Epoch 14 Batch 50 Loss 0.6874 Accuracy 0.3259
Epoch 14 Batch 100 Loss 0.6899 Accuracy 0.3258
Epoch 14 Batch 150 Loss 0.6993 Accuracy 0.3237
Epoch 14 Batch 200 Loss 0.7006 Accuracy 0.3228
Epoch 14 Batch 250 Loss 0.7048 Accuracy 0.3214
Epoch 14 Batch 300 Loss 0.7077 Accuracy 0.3213
Epoch 14 Batch 350 Loss 0.7136 Accuracy 0.3208
Epoch 14 Batch 400 Loss 0.7147 Accuracy 0.3203
Epoch 14 Batch 450 Loss 0.7173 Accuracy 0.3198
Epoch 14 Batch 500 Loss 0.7191 Accuracy 0.3196
Epoch 14 Batch 550 Loss 0.7223 Accuracy 0.3195
Epoch 14 Batch 600 Loss 0.7252 Accuracy 0.3192
Epoch 14 Batch 650 Loss 0.7260 Accuracy 0.3189
Epoch 14 Batch 700 Loss 0.7276 Accuracy 0.3184
Epoch 14 Loss 0.7282 Accuracy 0.3185
Time taken for 1 epoch: 788.9147775173187 secs

Epoch 15 Batch 0 Loss 0.6029 Accuracy 0.2973
Epoch 15 Batch 50 Loss 0.6479 Accuracy 0.3274
Epoch 15 Batch 100 Loss 0.6576 Accuracy 0.3274
Epoch 15 Batch 150 Loss 0.6656 Accuracy 0.3275
Epoch 15 Batch 200 Loss 0.6682 Accuracy 0.3258
Epoch 15 Batch 250 Loss 0.6706 Accuracy 0.3262
Epoch 15 Batch 300 Loss 0.6735 Accuracy 0.3261
Epoch 15 Batch 350 Loss 0.6747 Accuracy 0.3258
Epoch 15 Batch 400 Loss 0.6751 Accuracy 0.3253
Epoch 15 Batch 450 Loss 0.6775 Accuracy 0.3252
Epoch 15 Batch 500 Loss 0.6800 Accuracy 0.3247
Epoch 15 Batch 550 Loss 0.6842 Accuracy 0.3242
Epoch 15 Batch 600 Loss 0.6881 Accuracy 0.3241
Epoch 15 Batch 650 Loss 0.6894 Accuracy 0.3234
Epoch 15 Batch 700 Loss 0.6916 Accuracy 0.3229
Saving checkpoint for epoch 15 at ./checkpoints/train\ckpt-3
Epoch 15 Loss 0.6916 Accuracy 0.3229
Time taken for 1 epoch: 782.1990637779236 secs

Epoch 16 Batch 0 Loss 0.6690 Accuracy 0.3421
Epoch 16 Batch 50 Loss 0.6060 Accuracy 0.3344
Epoch 16 Batch 100 Loss 0.6161 Accuracy 0.3300
Epoch 16 Batch 150 Loss 0.6212 Accuracy 0.3302
Epoch 16 Batch 200 Loss 0.6311 Accuracy 0.3311
Epoch 16 Batch 250 Loss 0.6336 Accuracy 0.3299
Epoch 16 Batch 300 Loss 0.6393 Accuracy 0.3298
Epoch 16 Batch 350 Loss 0.6442 Accuracy 0.3298
Epoch 16 Batch 400 Loss 0.6470 Accuracy 0.3299
Epoch 16 Batch 450 Loss 0.6487 Accuracy 0.3296
Epoch 16 Batch 500 Loss 0.6513 Accuracy 0.3291
Epoch 16 Batch 550 Loss 0.6536 Accuracy 0.3286
Epoch 16 Batch 600 Loss 0.6554 Accuracy 0.3283
Epoch 16 Batch 650 Loss 0.6587 Accuracy 0.3282
Epoch 16 Batch 700 Loss 0.6608 Accuracy 0.3277
Epoch 16 Loss 0.6607 Accuracy 0.3277
Time taken for 1 epoch: 786.1227226257324 secs

Epoch 17 Batch 0 Loss 0.5516 Accuracy 0.3311
Epoch 17 Batch 50 Loss 0.5870 Accuracy 0.3383
Epoch 17 Batch 100 Loss 0.5873 Accuracy 0.3369
Epoch 17 Batch 150 Loss 0.5960 Accuracy 0.3367
Epoch 17 Batch 200 Loss 0.6006 Accuracy 0.3349
Epoch 17 Batch 250 Loss 0.6028 Accuracy 0.3336
Epoch 17 Batch 300 Loss 0.6076 Accuracy 0.3341
Epoch 17 Batch 350 Loss 0.6095 Accuracy 0.3332
Epoch 17 Batch 400 Loss 0.6155 Accuracy 0.3337
Epoch 17 Batch 450 Loss 0.6176 Accuracy 0.3332
Epoch 17 Batch 500 Loss 0.6205 Accuracy 0.3326
Epoch 17 Batch 550 Loss 0.6239 Accuracy 0.3328
Epoch 17 Batch 600 Loss 0.6266 Accuracy 0.3324
Epoch 17 Batch 650 Loss 0.6292 Accuracy 0.3322
Epoch 17 Batch 700 Loss 0.6319 Accuracy 0.3322
Epoch 17 Loss 0.6323 Accuracy 0.3322
Time taken for 1 epoch: 792.0246450901031 secs

Epoch 18 Batch 0 Loss 0.4853 Accuracy 0.3095
Epoch 18 Batch 50 Loss 0.5643 Accuracy 0.3417
Epoch 18 Batch 100 Loss 0.5668 Accuracy 0.3438
Epoch 18 Batch 150 Loss 0.5696 Accuracy 0.3420
Epoch 18 Batch 200 Loss 0.5774 Accuracy 0.3422
Epoch 18 Batch 250 Loss 0.5810 Accuracy 0.3416
Epoch 18 Batch 300 Loss 0.5800 Accuracy 0.3397
Epoch 18 Batch 350 Loss 0.5849 Accuracy 0.3396
Epoch 18 Batch 400 Loss 0.5895 Accuracy 0.3388
Epoch 18 Batch 450 Loss 0.5919 Accuracy 0.3380
Epoch 18 Batch 500 Loss 0.5952 Accuracy 0.3380
Epoch 18 Batch 550 Loss 0.5985 Accuracy 0.3379
Epoch 18 Batch 600 Loss 0.6005 Accuracy 0.3373
Epoch 18 Batch 650 Loss 0.6035 Accuracy 0.3369
Epoch 18 Batch 700 Loss 0.6061 Accuracy 0.3365
Epoch 18 Loss 0.6062 Accuracy 0.3364
Time taken for 1 epoch: 792.6422400474548 secs

Epoch 19 Batch 0 Loss 0.5587 Accuracy 0.3567
Epoch 19 Batch 50 Loss 0.5506 Accuracy 0.3490
Epoch 19 Batch 100 Loss 0.5450 Accuracy 0.3447
Epoch 19 Batch 150 Loss 0.5526 Accuracy 0.3441
Epoch 19 Batch 200 Loss 0.5551 Accuracy 0.3425
Epoch 19 Batch 250 Loss 0.5562 Accuracy 0.3412
Epoch 19 Batch 300 Loss 0.5600 Accuracy 0.3408
Epoch 19 Batch 350 Loss 0.5622 Accuracy 0.3411
Epoch 19 Batch 400 Loss 0.5662 Accuracy 0.3411
Epoch 19 Batch 450 Loss 0.5698 Accuracy 0.3412
Epoch 19 Batch 500 Loss 0.5730 Accuracy 0.3411
Epoch 19 Batch 550 Loss 0.5760 Accuracy 0.3409
Epoch 19 Batch 600 Loss 0.5770 Accuracy 0.3400
Epoch 19 Batch 650 Loss 0.5803 Accuracy 0.3402
Epoch 19 Batch 700 Loss 0.5832 Accuracy 0.3399
Epoch 19 Loss 0.5833 Accuracy 0.3399
Time taken for 1 epoch: 789.6633312702179 secs

Epoch 20 Batch 0 Loss 0.5405 Accuracy 0.3438
Epoch 20 Batch 50 Loss 0.5256 Accuracy 0.3551
Epoch 20 Batch 100 Loss 0.5263 Accuracy 0.3514
Epoch 20 Batch 150 Loss 0.5316 Accuracy 0.3499
Epoch 20 Batch 200 Loss 0.5335 Accuracy 0.3485
Epoch 20 Batch 250 Loss 0.5369 Accuracy 0.3469
Epoch 20 Batch 300 Loss 0.5427 Accuracy 0.3473
Epoch 20 Batch 350 Loss 0.5459 Accuracy 0.3468
Epoch 20 Batch 400 Loss 0.5485 Accuracy 0.3463
Epoch 20 Batch 450 Loss 0.5498 Accuracy 0.3457
Epoch 20 Batch 500 Loss 0.5518 Accuracy 0.3453
Epoch 20 Batch 550 Loss 0.5545 Accuracy 0.3442
Epoch 20 Batch 600 Loss 0.5580 Accuracy 0.3441
Epoch 20 Batch 650 Loss 0.5602 Accuracy 0.3437
Epoch 20 Batch 700 Loss 0.5627 Accuracy 0.3431
Saving checkpoint for epoch 20 at ./checkpoints/train\ckpt-4
Epoch 20 Loss 0.5631 Accuracy 0.3431
Time taken for 1 epoch: 9551.282238960266 secs
Evaluate

The following steps are used for evaluation:

  • Encode the input sentence using the Portuguese tokenizer (tokenizer_pt), and add the start and end tokens so the input matches what the model was trained on. This is the encoder input.
  • The decoder input is the start token == tokenizer_en.vocab_size.
  • Calculate the padding masks and the look-ahead masks.
  • The decoder then outputs the predictions by looking at the encoder output and its own output (self-attention).
  • Select the last word and calculate the argmax of it.
  • Concatenate the predicted word to the decoder input and pass it to the decoder.
  • In this approach, the decoder predicts the next word based on the previous words it predicted.

Note: the model used here has less capacity in order to stay relatively fast, so the predictions may be somewhat off. To reproduce the results in the paper, use the entire dataset and the base Transformer model or Transformer XL by changing the hyperparameters above.

def evaluate(inp_sentence):
  start_token = [tokenizer_pt.vocab_size]
  end_token = [tokenizer_pt.vocab_size + 1]
  
  # the input sentence is Portuguese; add the start and end tokens
  inp_sentence = start_token + tokenizer_pt.encode(inp_sentence) + end_token
  encoder_input = tf.expand_dims(inp_sentence, 0)
  
  # as the target is English, the first word passed to the transformer
  # should be the English start token.
  decoder_input = [tokenizer_en.vocab_size]
  output = tf.expand_dims(decoder_input, 0)
    
  for i in range(MAX_LENGTH):
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
        encoder_input, output)
  
    # predictions.shape == (batch_size, seq_len, vocab_size)
    predictions, attention_weights = transformer(encoder_input, 
                                                 output,
                                                 False,
                                                 enc_padding_mask,
                                                 combined_mask,
                                                 dec_padding_mask)
    
    # select the last word from the seq_len dimension
    predictions = predictions[: ,-1:, :]  # (batch_size, 1, vocab_size)

    predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)
    
    # return the result if predicted_id is equal to the end token
    if predicted_id == tokenizer_en.vocab_size+1:
      return tf.squeeze(output, axis=0), attention_weights
    
    # concatenate predicted_id to the output, which is given to the decoder as its input.
    output = tf.concat([output, predicted_id], axis=-1)

  return tf.squeeze(output, axis=0), attention_weights
def plot_attention_weights(attention, sentence, result, layer):
  fig = plt.figure(figsize=(16, 8))
  
  sentence = tokenizer_pt.encode(sentence)
  
  attention = tf.squeeze(attention[layer], axis=0)
  
  for head in range(attention.shape[0]):
    ax = fig.add_subplot(2, 4, head+1)
    
    # plot the attention weights
    ax.matshow(attention[head][:-1, :], cmap='viridis')

    fontdict = {'fontsize': 10}
    
    ax.set_xticks(range(len(sentence)+2))
    ax.set_yticks(range(len(result)))
    
    ax.set_ylim(len(result)-1.5, -0.5)
    
    xlabel=['']+[tokenizer_pt.decode([i]) for i in sentence]+['']
    print(xlabel)
    '''
    ax.set_xticklabels(
        ['']+[tokenizer_pt.decode([i]) for i in sentence]+[''], 
        fontdict=fontdict, rotation=90)
    '''
    ax.set_xticklabels(xlabel,  fontdict=fontdict, rotation=90)
    
    test =[tokenizer_en.decode([i]) for i in result if i < tokenizer_en.vocab_size]
    print(len(result))
    print(len(test))
    print(tokenizer_en.vocab_size)
    print(result)
    print(test)
    test =test+['.']
    #ax.set_yticklabels([tokenizer_en.decode([i]) for i in result if i < tokenizer_en.vocab_size], fontdict=fontdict)
    
    #ax.set_yticklabels(xlabel, fontdict=fontdict)
    
    ax.set_yticklabels(test, fontdict=fontdict)
    ax.set_xlabel('Head {}'.format(head+1))
  
  plt.tight_layout()
  plt.show()
def translate(sentence, plot=''):
  result, attention_weights = evaluate(sentence)
  
  predicted_sentence = tokenizer_en.decode([i for i in result 
                                            if i < tokenizer_en.vocab_size])  

  print('Input: {}'.format(sentence))
  print('Predicted translation: {}'.format(predicted_sentence))
  
  if plot:
    plot_attention_weights(attention_weights, sentence, result, plot)
translate("este é um problema que temos que resolver.")
print ("Real translation: this is a problem we have to solve .")
Input: este é um problema que temos que resolver.
Predicted translation: this is a problem that we have to solve .... to do n't cooperate .
Real translation: this is a problem we have to solve .
translate("os meus vizinhos ouviram sobre esta ideia.")
print ("Real translation: and my neighboring homes heard about this idea .")
Input: os meus vizinhos ouviram sobre esta ideia.
Predicted translation: my neighbors heard about this idea .
Real translation: and my neighboring homes heard about this idea .
translate("vou então muito rapidamente partilhar convosco algumas histórias de algumas coisas mágicas que aconteceram.")
print ("Real translation: so i 'll just share with you some stories very quickly of some magical things that have happened .")
Input: vou então muito rapidamente partilhar convosco algumas histórias de algumas coisas mágicas que aconteceram.
Predicted translation: so i 'm going to really quickly share with you some magic stories that happened .
Real translation: so i 'll just share with you some stories very quickly of some magical things that have happened .
import matplotlib.pyplot as plt

You can pass different layers and attention blocks of the decoder to the plot parameter.

translate("este é o primeiro livro que eu fiz.", plot='decoder_layer4_block2')
print ("Real translation: this is the first book i've ever done.")
Input: este é o primeiro livro que eu fiz.
Predicted translation: this is the first book that i did it .
['', 'este ', 'é ', 'o ', 'primeiro ', 'livro ', 'que ', 'eu ', 'fiz', '.', '']
11
10
8087
tf.Tensor([8087   16   13    3  124  774   10   12   98   19    2], shape=(11,), dtype=int32)
['this ', 'is ', 'the ', 'first ', 'book ', 'that ', 'i ', 'did ', 'it', ' .']
['', 'este ', 'é ', 'o ', 'primeiro ', 'livro ', 'que ', 'eu ', 'fiz', '.', '']
11
10
8087
tf.Tensor([8087   16   13    3  124  774   10   12   98   19    2], shape=(11,), dtype=int32)
['this ', 'is ', 'the ', 'first ', 'book ', 'that ', 'i ', 'did ', 'it', ' .']
['', 'este ', 'é ', 'o ', 'primeiro ', 'livro ', 'que ', 'eu ', 'fiz', '.', '']
11
10
8087
tf.Tensor([8087   16   13    3  124  774   10   12   98   19    2], shape=(11,), dtype=int32)
['this ', 'is ', 'the ', 'first ', 'book ', 'that ', 'i ', 'did ', 'it', ' .']
['', 'este ', 'é ', 'o ', 'primeiro ', 'livro ', 'que ', 'eu ', 'fiz', '.', '']
11
10
8087
tf.Tensor([8087   16   13    3  124  774   10   12   98   19    2], shape=(11,), dtype=int32)
['this ', 'is ', 'the ', 'first ', 'book ', 'that ', 'i ', 'did ', 'it', ' .']
['', 'este ', 'é ', 'o ', 'primeiro ', 'livro ', 'que ', 'eu ', 'fiz', '.', '']
11
10
8087
tf.Tensor([8087   16   13    3  124  774   10   12   98   19    2], shape=(11,), dtype=int32)
['this ', 'is ', 'the ', 'first ', 'book ', 'that ', 'i ', 'did ', 'it', ' .']
['', 'este ', 'é ', 'o ', 'primeiro ', 'livro ', 'que ', 'eu ', 'fiz', '.', '']
11
10
8087
tf.Tensor([8087   16   13    3  124  774   10   12   98   19    2], shape=(11,), dtype=int32)
['this ', 'is ', 'the ', 'first ', 'book ', 'that ', 'i ', 'did ', 'it', ' .']
['', 'este ', 'é ', 'o ', 'primeiro ', 'livro ', 'que ', 'eu ', 'fiz', '.', '']
11
10
8087
tf.Tensor([8087   16   13    3  124  774   10   12   98   19    2], shape=(11,), dtype=int32)
['this ', 'is ', 'the ', 'first ', 'book ', 'that ', 'i ', 'did ', 'it', ' .']
['', 'este ', 'é ', 'o ', 'primeiro ', 'livro ', 'que ', 'eu ', 'fiz', '.', '']
11
10
8087
tf.Tensor([8087   16   13    3  124  774   10   12   98   19    2], shape=(11,), dtype=int32)
['this ', 'is ', 'the ', 'first ', 'book ', 'that ', 'i ', 'did ', 'it', ' .']




Real translation: this is the first book i've ever done.

!pip list
Package                           Version

WARNING: Ignoring invalid distribution -atplotlib (e:\anaconda3\envs\my_star_space\lib\site-packages)
WARNING: Ignoring invalid distribution -atplotlib (e:\anaconda3\envs\my_star_space\lib\site-packages)
WARNING: Ignoring invalid distribution -atplotlib (e:\anaconda3\envs\my_star_space\lib\site-packages)
WARNING: Ignoring invalid distribution -atplotlib (e:\anaconda3\envs\my_star_space\lib\site-packages)
WARNING: You are using pip version 21.2.4; however, version 21.3.1 is available.
You should consider upgrading via the 'e:\anaconda3\envs\my_star_space\python.exe -m pip install --upgrade pip' command.



--------------------------------- -----------
absl-py                           0.12.0
alembic                           1.4.1
allennlp                          2.7.0
allennlp-models                   2.7.0
aniso8601                         9.0.1
anyio                             3.3.2
argon2-cffi                       21.1.0
asgiref                           3.4.1
astunparse                        1.6.3
async-generator                   1.10
atomicwrites                      1.4.0
attrs                             21.2.0
azure-core                        1.18.0
azure-storage-blob                12.8.1
Babel                             2.9.1
backcall                          0.2.0
backports.csv                     1.0.7
backports.entry-points-selectable 1.1.0
base58                            2.1.0
beautifulsoup4                    4.10.0
bertviz                           1.2.0
bleach                            4.1.0
blis                              0.7.4
boto3                             1.18.37
botocore                          1.21.37
bottle                            0.12.19
cached-property                   1.5.2
cachetools                        4.2.2
catalogue                         2.0.6
certifi                           2021.5.30
cffi                              1.14.6
chardet                           4.0.0
charset-normalizer                2.0.4
checklist                         0.0.11
cheroot                           8.5.2
CherryPy                          18.6.1
clang                             5.0
click                             7.1.2
cloudpickle                       1.6.0
colorama                          0.4.4
conan                             1.39.0
configparser                      5.0.2
conllu                            4.4.1
contextvars                       2.4
coverage                          5.5
cryptography                      3.4.8
cycler                            0.10.0
cymem                             2.0.5
Cython                            0.29.14
databricks-cli                    0.15.0
dataclasses                       0.8
datasets                          1.11.0
decorator                         4.4.2
defusedxml                        0.7.1
deprecation                       2.0.7
dill                              0.3.4
distlib                           0.3.2
distro                            1.5.0
docker                            5.0.2
docker-pycreds                    0.4.0
dotmap                            1.3.24
elastic-apm                       6.4.0
elasticsearch                     7.10.0
entrypoints                       0.3
fairscale                         0.4.0
faiss-cpu                         1.7.1.post2
farm                              0.7.1
farm-haystack                     0.8.0
fastapi                           0.68.1
fasteners                         0.16.3
feedparser                        6.0.8
filelock                          3.0.12
Flask                             1.1.4
Flask-Cors                        3.0.10
flask-restplus                    0.13.0
flatbuffers                       1.12
fsspec                            2021.8.1
ftfy                              6.0.3
future                            0.18.2
gast                              0.4.0
gensim                            3.8.3
gitdb                             4.0.7
GitPython                         3.1.20
google-api-core                   2.0.1
google-auth                       1.35.0
google-auth-oauthlib              0.4.6
google-cloud-core                 2.0.0
google-cloud-storage              1.42.0
google-crc32c                     1.1.2
google-pasta                      0.2.0
google-resumable-media            2.0.2
googleapis-common-protos          1.53.0
greenlet                          1.1.1
grpcio                            1.37.1
grpcio-tools                      1.37.1
gunicorn                          20.1.0
h11                               0.12.0
h5py                              3.1.0
httptools                         0.3.0
huggingface-hub                   0.0.16
idna                              3.2
immutables                        0.16
importlib-metadata                4.8.1
importlib-resources               5.2.2
iniconfig                         1.1.1
ipykernel                         5.5.5
ipython                           7.16.1
ipython-genutils                  0.2.0
ipywidgets                        7.6.4
iso-639                           0.4.5
isodate                           0.6.0
itsdangerous                      1.1.0
jaraco.classes                    3.2.1
jaraco.collections                3.4.0
jaraco.functools                  3.3.0
jaraco.text                       3.5.1
jax                               0.2.9
jedi                              0.18.0
Jinja2                            2.11.3
jmespath                          0.10.0
joblib                            1.0.1
json5                             0.9.6
jsonschema                        3.2.0
jupyter                           1.0.0
jupyter-client                    7.0.2
jupyter-console                   6.4.0
jupyter-core                      4.7.1
jupyter-server                    1.11.0
jupyterlab                        3.1.14
jupyterlab-pygments               0.1.2
jupyterlab-server                 2.8.2
jupyterlab-widgets                1.0.1
keras                             2.6.0
Keras-Preprocessing               1.1.2
kiwisolver                        1.3.1
langdetect                        1.0.9
lmdb                              1.2.1
lxml                              4.6.3
Mako                              1.1.5
Markdown                          3.3.4
MarkupSafe                        2.0.1
matplotlib                        3.0.3
mistune                           0.8.4
mkl-fft                           1.3.0
mkl-random                        1.1.1
mkl-service                       2.3.0
mlflow                            1.13.1
more-itertools                    8.9.0
msrest                            0.6.21
multiprocess                      0.70.12.2
munch                             2.5.0
murmurhash                        1.0.5
nbclassic                         0.3.2
nbclient                          0.5.4
nbconvert                         6.0.7
nbformat                          5.1.3
nest-asyncio                      1.5.1
networkx                          2.5.1
nltk                              3.6.2
node-semver                       0.6.1
notebook                          6.4.3
numpy                             1.19.5
oauthlib                          3.1.1
opt-einsum                        3.3.0
overrides                         3.1.0
packaging                         21.0
pandas                            1.1.5
pandocfilters                     1.4.3
parso                             0.8.2
patch-ng                          1.17.4
pathtools                         0.1.2
pathy                             0.6.0
patternfork-nosql                 3.6
pdfminer.six                      20201018
pickleshare                       0.7.5
Pillow                            8.3.2
pip                               21.2.4
platformdirs                      2.3.0
pluggy                            1.0.0
pluginbase                        1.0.1
portend                           2.7.1
preshed                           3.0.5
prometheus-client                 0.11.0
prometheus-flask-exporter         0.18.2
promise                           2.3
prompt-toolkit                    3.0.20
protobuf                          3.17.3
psutil                            5.8.0
py                                1.10.0
py-rouge                          1.1
pyarrow                           5.0.0
pyasn1                            0.4.8
pyasn1-modules                    0.2.8
pycparser                         2.20
pydantic                          1.8.2
Pygments                          2.10.0
PyJWT                             1.7.1
pymilvus                          1.1.2
pyparsing                         2.4.7
pyrsistent                        0.18.0
pytest                            6.2.5
python-dateutil                   2.8.2
python-docx                       0.8.11
python-editor                     1.0.4
python-multipart                  0.0.5
pytz                              2021.1
pywin32                           227
pywinpty                          1.1.4
PyYAML                            5.4.1
pyzmq                             22.2.1
qtconsole                         5.1.1
QtPy                              1.11.0
querystring-parser                1.2.4
rdflib                            5.0.0
regex                             2021.8.28
requests                          2.26.0
requests-oauthlib                 1.3.0
requests-unixsocket               0.2.0
rsa                               4.7.2
s3transfer                        0.5.0
sacremoses                        0.0.45
scikit-learn                      0.24.2
scipy                             1.5.2
seaborn                           0.11.2
Send2Trash                        1.8.0
sentencepiece                     0.1.94
sentry-sdk                        1.3.1
seqeval                           1.2.2
setuptools                        58.0.4
sgmllib3k                         1.0.0
shortuuid                         1.0.1
six                               1.16.0
sklearn                           0.0
smart-open                        5.2.1
smmap                             4.0.0
sniffio                           1.2.0
sortedcontainers                  2.4.0
soupsieve                         2.2.1
spacy                             3.1.2
spacy-legacy                      3.0.8
SPARQLWrapper                     1.8.5
SQLAlchemy                        1.4.23
SQLAlchemy-Utils                  0.37.8
sqlitedict                        1.7.0
sqlparse                          0.4.1
srsly                             2.4.1
starlette                         0.14.2
subprocess32                      3.5.4
tabulate                          0.8.9
tempora                           4.1.1
tensorboard                       2.6.0
tensorboard-data-server           0.6.1
tensorboard-plugin-wit            1.8.0
tensorboardX                      2.4
tensorflow                        2.6.0
tensorflow-datasets               4.4.0
tensorflow-estimator              2.6.0
tensorflow-metadata               1.2.0
termcolor                         1.1.0
terminado                         0.12.1
testpath                          0.5.0
thinc                             8.0.10
threadpoolctl                     2.2.0
tika                              1.24
tokenizers                        0.9.4
toml                              0.10.2
toposort                          1.6
torch                             1.7.1
torchtext                         0.10.0
torchvision                       0.10.0
tornado                           6.1
tox                               3.24.3
tqdm                              4.62.2
traitlets                         4.3.3
transformers                      4.0.0
typer                             0.3.2
typing-extensions                 3.7.4.3
ujson                             4.1.0
urllib3                           1.25.11
uvicorn                           0.15.0
virtualenv                        20.7.2
waitress                          2.0.0
wandb                             0.12.1
wasabi                            0.8.2
wcwidth                           0.2.5
webencodings                      0.5.1
websocket-client                  1.2.1
Werkzeug                          0.16.1
wheel                             0.37.0
widgetsnbextension                3.5.1
wincertstore                      0.2
word2number                       1.1
wrapt                             1.12.1
xxhash                            2.0.2
zc.lockfile                       2.0
zipp                              3.5.0
Summary

In this tutorial, you have learned about positional encoding, multi-head attention, the importance of masking, and how to build a Transformer.

Try using a different dataset to train the Transformer. You can also create the base Transformer or Transformer XL by changing the hyperparameters above. You can also use the layers defined here to create BERT and train state-of-the-art models. Furthermore, you can implement beam search to get better predictions; a minimal sketch follows below.
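
On that last suggestion, here is a minimal beam search sketch built on top of the evaluate() setup above. It assumes the same transformer, tokenizers, create_masks and MAX_LENGTH; beam_width and the length-penalty exponent alpha are illustrative choices, and scoring one beam at a time keeps the sketch simple at the cost of speed.

def beam_search_evaluate(inp_sentence, beam_width=4, alpha=0.6):
  start_token = [tokenizer_pt.vocab_size]
  end_token = [tokenizer_pt.vocab_size + 1]
  inp_ids = start_token + tokenizer_pt.encode(inp_sentence) + end_token
  encoder_input = tf.expand_dims(inp_ids, 0)

  en_end_id = tokenizer_en.vocab_size + 1
  # each beam is (token_ids, summed_log_prob, finished)
  beams = [([tokenizer_en.vocab_size], 0.0, False)]

  for _ in range(MAX_LENGTH):
    candidates = []
    for tokens, log_prob, finished in beams:
      if finished:
        candidates.append((tokens, log_prob, True))
        continue
      output = tf.expand_dims(tokens, 0)
      enc_mask, comb_mask, dec_mask = create_masks(encoder_input, output)
      predictions, _ = transformer(encoder_input, output, False,
                                   enc_mask, comb_mask, dec_mask)
      log_probs = tf.nn.log_softmax(predictions[0, -1, :], axis=-1)
      top_vals, top_ids = tf.math.top_k(log_probs, k=beam_width)
      for val, idx in zip(top_vals.numpy(), top_ids.numpy()):
        candidates.append((tokens + [int(idx)], log_prob + float(val),
                           int(idx) == en_end_id))
    # keep the beam_width best candidates under a length-normalized score
    candidates.sort(key=lambda c: c[1] / (len(c[0]) ** alpha), reverse=True)
    beams = candidates[:beam_width]
    if all(finished for _, _, finished in beams):
      break

  best_tokens = beams[0][0]
  return tokenizer_en.decode([i for i in best_tokens
                              if i < tokenizer_en.vocab_size])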

References
  • TensorFlow tutorial: https://tensorflow.google.cn/tutorials/text/transformer

  • Transformer Course: Understanding the Transformer Model for Language - Positional Encoding and Masking

  • NLP Star Space Intelligent Chatbot Series: Understanding the Transformer Model for Language - Subword Tokenizer
