How to create a custom Estimator in PySpark



Generally speaking, there is no documentation for this, because in Spark 1.6 / 2.0 most of the relevant APIs were not intended to be public. That should change in Spark 2.1.0 (see SPARK-7146).

The API is relatively complex, because it has to follow specific conventions in order to make a given Transformer or Estimator compatible with the Pipeline API. Some of these methods may be required for features such as reading/writing or grid search. Others, such as keyword_only, are just simple helpers and are not strictly required.

Assuming you have defined the following mixin for the mean parameter:

from pyspark.ml.pipeline import Estimator, Model, Pipeline
from pyspark.ml.param.shared import *
from pyspark.sql.functions import avg, stddev_samp


class HasMean(Params):
    mean = Param(Params._dummy(), "mean", "mean",
        typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(HasMean, self).__init__()

    def setMean(self, value):
        return self._set(mean=value)

    def getMean(self):
        return self.getOrDefault(self.mean)

one for the standard deviation parameter:

class HasStandardDeviation(Params):
    standardDeviation = Param(Params._dummy(),
        "standardDeviation", "standardDeviation",
        typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(HasStandardDeviation, self).__init__()

    def setStddev(self, value):
        return self._set(standardDeviation=value)

    def getStddev(self):
        return self.getOrDefault(self.standardDeviation)

and one for the threshold:

class HasCenteredThreshold(Params):
    centeredThreshold = Param(Params._dummy(),
        "centeredThreshold", "centeredThreshold",
        typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(HasCenteredThreshold, self).__init__()

    def setCenteredThreshold(self, value):
        return self._set(centeredThreshold=value)

    def getCenteredThreshold(self):
        return self.getOrDefault(self.centeredThreshold)

you can create a basic Estimator as follows:

from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark import keyword_only


class NormalDeviation(Estimator, HasInputCol,
        HasPredictionCol, HasCenteredThreshold,
        # Available in PySpark >= 2.3.0
        # Credits https://stackoverflow.com/a/52467470
        # by https://stackoverflow.com/users/234944/benjamin-manns
        DefaultParamsReadable, DefaultParamsWritable):

    @keyword_only
    def __init__(self, inputCol=None, predictionCol=None, centeredThreshold=1.0):
        super(NormalDeviation, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    # Required in Spark >= 3.0
    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    # Required in Spark >= 3.0
    def setPredictionCol(self, value):
        """
        Sets the value of :py:attr:`predictionCol`.
        """
        return self._set(predictionCol=value)

    @keyword_only
    def setParams(self, inputCol=None, predictionCol=None, centeredThreshold=1.0):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _fit(self, dataset):
        c = self.getInputCol()
        mu, sigma = dataset.agg(avg(c), stddev_samp(c)).first()
        return NormalDeviationModel(
            inputCol=c, mean=mu, standardDeviation=sigma,
            centeredThreshold=self.getCenteredThreshold(),
            predictionCol=self.getPredictionCol())


class NormalDeviationModel(Model, HasInputCol, HasPredictionCol,
        HasMean, HasStandardDeviation, HasCenteredThreshold,
        DefaultParamsReadable, DefaultParamsWritable):

    @keyword_only
    def __init__(self, inputCol=None, predictionCol=None,
            mean=None, standardDeviation=None,
            centeredThreshold=None):
        super(NormalDeviationModel, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, predictionCol=None,
            mean=None, standardDeviation=None,
            centeredThreshold=None):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _transform(self, dataset):
        x = self.getInputCol()
        y = self.getPredictionCol()
        threshold = self.getCenteredThreshold()
        mu = self.getMean()
        sigma = self.getStddev()
        return dataset.withColumn(y, (dataset[x] - mu) > threshold * sigma)

Finally, it can be used as follows:

df = sc.parallelize([(1, 2.0), (2, 3.0), (3, 0.0), (4, 99.0)]).toDF(["id", "x"])

normal_deviation = NormalDeviation().setInputCol("x").setCenteredThreshold(1.0)
model = Pipeline(stages=[normal_deviation]).fit(df)

model.transform(df).show()
## +---+----+----------+
## | id|   x|prediction|
## +---+----+----------+
## |  1| 2.0|     false|
## |  2| 3.0|     false|
## |  3| 0.0|     false|
## |  4|99.0|      true|
## +---+----+----------+



Original source: http://outofmemory.cn/zaji/5639469.html
