在keras自定义层中进行广播的逐元素乘法

在keras自定义层中进行广播的逐元素乘法,第1张

在keras自定义层中进行广播的逐元素乘法

乘法之前,您需要重复元素以增加形状。您可以使用

K.repeat_elements
它。(
import keras.backend as K

class MyLayer(Layer):    #there are some difficulties for different types of shapes       #let's use a 'repeat_count' instead, increasing only one dimension    def __init__(self, repeat_count,**kwargs):        self.repeat_count = repeat_count        super(MyLayer, self).__init__(**kwargs)    def build(self, input_shape):        #first, let's get the output_shape        output_shape = self.compute_output_shape(input_shape)        weight_shape = (1,) + output_shape[1:] #replace the batch size by 1        self.kernel = self.add_weight(name='kernel',     shape=weight_shape,     initializer='ones',     trainable=True)        super(MyLayer, self).build(input_shape)  # Be sure to call this somewhere!    #here, we need to repeat the elements before multiplying    def call(self, x):        if self.repeat_count > 1:  #we add the extra dimension:  x = K.expand_dims(x, axis=1)  #we replicate the elements  x = K.repeat_elements(x, rep=self.repeat_count, axis=1)        #multiply        return x * self.kernel    #make sure we comput the ouptut shape according to what we did in "call"    def compute_output_shape(self, input_shape):        if self.repeat_count > 1: return (input_shape[0],self.repeat_count) + input_shape[1:]        else: return input_shape


欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5003658.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-11-14
下一篇 2022-11-14

发表评论

登录后才能评论

评论列表(0条)

保存