张量流中的二进制搜索和内插

张量流中的二进制搜索和内插,第1张

张量流中的二进制搜索和内插

我不知道您的错误来源,但我可以告诉您,这

tf.while_loop
很可能非常缓慢。您可以实现没有循环的线性插值,如下所示:

import numpy as npimport tensorflow as tfxaxis = tf.placeholder(tf.float32, shape=100, name='xaxis')yaxis = tf.placeholder(tf.float32, shape=100, name='yaxis')query = tf.placeholder(tf.float32, name='query')# Add additional elements at the beginning and end for extrapolationxaxis_pad = tf.concat([[tf.minimum(query - 1, xaxis[0])], xaxis, [tf.maximum(query + 1, xaxis[-1])]], axis=0)yaxis_pad = tf.concat([yaxis[:1], yaxis, yaxis[-1:]], axis=0)# Find the index of the interval containing querycmp = tf.cast(query >= xaxis_pad, dtype=tf.int32)diff = cmp[1:] - cmp[:-1]idx = tf.argmin(diff)# Interpolatealpha = (query - xaxis_pad[idx]) / (xaxis_pad[idx + 1] - xaxis_pad[idx])res = alpha * yaxis_pad[idx + 1] + (1 - alpha) * yaxis_pad[idx]# Test with f(x) = 2 * xq = 5.4x = np.arange(100)y = 2 * xwith tf.Session() as sess:    q_interp = sess.run(res, feed_dict={xaxis: x, yaxis: y, query: q})print(q_interp)>>> 10.8

填充部分只是为了避免麻烦(如果您将值传递到范围之外),否则只是比较和查找值开始大于的问题

query



欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5662001.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-16
下一篇 2022-12-16

发表评论

登录后才能评论

评论列表(0条)

保存