我不知道您的错误来源,但我可以告诉您,这
tf.while_loop很可能非常缓慢。您可以实现没有循环的线性插值,如下所示:
import numpy as npimport tensorflow as tfxaxis = tf.placeholder(tf.float32, shape=100, name='xaxis')yaxis = tf.placeholder(tf.float32, shape=100, name='yaxis')query = tf.placeholder(tf.float32, name='query')# Add additional elements at the beginning and end for extrapolationxaxis_pad = tf.concat([[tf.minimum(query - 1, xaxis[0])], xaxis, [tf.maximum(query + 1, xaxis[-1])]], axis=0)yaxis_pad = tf.concat([yaxis[:1], yaxis, yaxis[-1:]], axis=0)# Find the index of the interval containing querycmp = tf.cast(query >= xaxis_pad, dtype=tf.int32)diff = cmp[1:] - cmp[:-1]idx = tf.argmin(diff)# Interpolatealpha = (query - xaxis_pad[idx]) / (xaxis_pad[idx + 1] - xaxis_pad[idx])res = alpha * yaxis_pad[idx + 1] + (1 - alpha) * yaxis_pad[idx]# Test with f(x) = 2 * xq = 5.4x = np.arange(100)y = 2 * xwith tf.Session() as sess: q_interp = sess.run(res, feed_dict={xaxis: x, yaxis: y, query: q})print(q_interp)>>> 10.8
填充部分只是为了避免麻烦(如果您将值传递到范围之外),否则只是比较和查找值开始大于的问题
query。
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)