您可以把当前的代码泛化为一个通用函数:
from functools import reduce
from operator import add

from pyspark.sql.functions import coalesce, lit, col, lead, lag


def weighted_average(c, window, offsets, weights):
    """Build a Column holding a weighted sum of ``c`` over neighbouring rows.

    For every ``(offset, weight)`` pair the value of ``c`` at the given row
    offset inside ``window`` is multiplied by ``weight``; the products are
    then added together.

    Args:
        c: Column expression to combine.
        window: Window specification passed to ``.over()`` for every
            non-zero offset.
        offsets: Row offsets relative to the current row. Negative offsets
            are resolved with ``lag``, positive ones with ``lead``, and ``0``
            uses ``c`` itself.
        weights: One weight per offset, in the same order as ``offsets``.

    Returns:
        A Column expression equal to the sum of ``value_i * weight_i``.
        Any term whose product is NULL contributes a literal 0 instead, so
        the result is NOT re-normalized when some terms are missing.

    Raises:
        AssertionError: If ``offsets`` and ``weights`` differ in length.
    """
    assert len(weights) == len(offsets), \
        "offsets and weights must have the same length"

    def value(i):
        # Negative offset -> earlier row, positive offset -> later row.
        if i < 0:
            return lag(c, -i).over(window)
        if i > 0:
            return lead(c, i).over(window)
        return c

    # Create a list of Columns
    # - `value_i * weight_i` if `value_i IS NOT NULL`
    # - literal 0 otherwise
    values = [coalesce(value(i) * w, lit(0)) for i, w in zip(offsets, weights)]

    # or sum(values, lit(0))
    return reduce(add, values, lit(0))
它可以用作:
from pyspark.sql.window import Window

# NOTE(review): `spark` is assumed to be an already-created SparkSession
# (e.g. the one provided by the pyspark shell) -- confirm in your environment.
df = spark.createDataFrame([
    ("a", 1, 1.4), ("a", 2, 8.0), ("a", 3, -1.0), ("a", 4, 2.4),
    ("a", 5, 99.0), ("a", 6, 3.0), ("a", 7, -1.0), ("a", 8, 0.0),
]).toDF("id", "time", "value")

# Rows are ordered by "time" within each "id".
w = Window.partitionBy("id").orderBy("time")

# `delays` holds the weight applied at each offset (one weight per offset).
offsets, delays = [-2, -1, 0, 1, 2], [0.1, 0.20, 0.4, 0.20, 0.1]

result = df.withColumn(
    "avg", weighted_average(col("value"), w, offsets, delays)
)
result.show()

## +---+----+-----+-------------------+
## | id|time|value|                avg|
## +---+----+-----+-------------------+
## |  a|   1|  1.4|               2.06|
## |  a|   2|  8.0| 3.5199999999999996|
## |  a|   3| -1.0|              11.72|
## |  a|   4|  2.4|              21.66|
## |  a|   5| 99.0| 40.480000000000004|
## |  a|   6|  3.0|              21.04|
## |  a|   7| -1.0|               10.1|
## |  a|   8|  0.0|0.10000000000000003|
## +---+----+-----+-------------------+
注意:在序列的开头和结尾处,窗口内缺失的前/后行会被当作 0 处理,因此实际参与计算的权重之和小于 1(见下表中的 normalization_factor 列)。可以考虑对这些行的结果进行归一化:
# Running the same weights over a constant 1 yields, for each row, the sum
# of the weights whose offset row actually exists in the window; missing
# rows contribute 0.
factor = weighted_average(lit(1), w, offsets, delays)

# Divide the raw weighted sum by that factor to get the normalized average.
(
    result
    .withColumn("normalization_factor", factor)
    .withColumn("normalized_avg", col("avg") / col("normalization_factor"))
    .show()
)

## +---+----+-----+-------------------+--------------------+------------------+
## | id|time|value|                avg|normalization_factor|    normalized_avg|
## +---+----+-----+-------------------+--------------------+------------------+
## |  a|   1|  1.4|               2.06|  0.7000000000000001|2.9428571428571426|
## |  a|   2|  8.0| 3.5199999999999996|                 0.9|3.9111111111111105|
## |  a|   3| -1.0|              11.72|  1.0000000000000002|11.719999999999999|
## |  a|   4|  2.4|              21.66|  1.0000000000000002|21.659999999999997|
## |  a|   5| 99.0| 40.480000000000004|  1.0000000000000002|             40.48|
## |  a|   6|  3.0|              21.04|  1.0000000000000002|21.039999999999996|
## |  a|   7| -1.0|               10.1|  0.9000000000000001| 11.22222222222222|
## |  a|   8|  0.0|0.10000000000000003|  0.7000000000000001|0.1428571428571429|
## +---+----+-----+-------------------+--------------------+------------------+
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)