Generally speaking, there is no documentation, because as of Spark 1.6 / 2.0 most of the relevant APIs are not intended to be public. This should change in Spark 2.1.0 (see SPARK-7146).
The API is relatively complex because it has to follow specific conventions to make a given Transformer or Estimator compatible with the Pipeline API. Some of these methods may be required by features such as reading and writing or grid search. Others, like keyword_only, are just simple helpers and are not strictly required.
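For illustration, here is a minimal sketch of what keyword_only does (the Greeter class is purely hypothetical): it rejects positional arguments and stores the keyword arguments in self._input_kwargs, which is the pattern the __init__ / setParams methods below rely on.

from pyspark import keyword_only
from pyspark.ml.param import Params

class Greeter(Params):
    @keyword_only
    def __init__(self, greeting=None):
        super(Greeter, self).__init__()
        # keyword_only has already stored the kwargs on the instance:
        print(self._input_kwargs)

Greeter(greeting="hi")   # prints {'greeting': 'hi'}
# Greeter("hi")          # would raise TypeError: positional args are rejected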
Assume you have defined the following mix-in for a mean parameter:
from pyspark.ml.pipeline import Estimator, Model, Pipeline
from pyspark.ml.param.shared import *
from pyspark.sql.functions import avg, stddev_samp


class HasMean(Params):

    mean = Param(Params._dummy(), "mean", "mean",
                 typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(HasMean, self).__init__()

    def setMean(self, value):
        return self._set(mean=value)

    def getMean(self):
        return self.getOrDefault(self.mean)
one for a standard deviation parameter:
class HasStandardDeviation(Params):

    standardDeviation = Param(Params._dummy(),
                              "standardDeviation", "standardDeviation",
                              typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(HasStandardDeviation, self).__init__()

    def setStddev(self, value):
        return self._set(standardDeviation=value)

    def getStddev(self):
        return self.getOrDefault(self.standardDeviation)
and one for a centered threshold:
class HasCenteredThreshold(Params):

    centeredThreshold = Param(Params._dummy(),
                              "centeredThreshold", "centeredThreshold",
                              typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(HasCenteredThreshold, self).__init__()

    def setCenteredThreshold(self, value):
        return self._set(centeredThreshold=value)

    def getCenteredThreshold(self):
        return self.getOrDefault(self.centeredThreshold)
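As a quick, illustrative sanity check (not part of the original answer), each mix-in can be exercised on its own:

m = HasMean()
m.setMean(42)             # TypeConverters.toFloat coerces the int to a float
print(m.getMean())        # 42.0
print(m.isSet(m.mean))    # True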
Then you can create a basic Estimator as follows:
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark import keyword_only


class NormalDeviation(Estimator, HasInputCol,
                      HasPredictionCol, HasCenteredThreshold,
                      # Available in PySpark >= 2.3.0
                      # Credits https://stackoverflow.com/a/52467470
                      # by https://stackoverflow.com/users/234944/benjamin-manns
                      DefaultParamsReadable, DefaultParamsWritable):

    @keyword_only
    def __init__(self, inputCol=None, predictionCol=None, centeredThreshold=1.0):
        super(NormalDeviation, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    # Required in Spark >= 3.0
    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    # Required in Spark >= 3.0
    def setPredictionCol(self, value):
        """
        Sets the value of :py:attr:`predictionCol`.
        """
        return self._set(predictionCol=value)

    @keyword_only
    def setParams(self, inputCol=None, predictionCol=None, centeredThreshold=1.0):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _fit(self, dataset):
        c = self.getInputCol()
        mu, sigma = dataset.agg(avg(c), stddev_samp(c)).first()
        return NormalDeviationModel(
            inputCol=c, mean=mu, standardDeviation=sigma,
            centeredThreshold=self.getCenteredThreshold(),
            predictionCol=self.getPredictionCol())


class NormalDeviationModel(Model, HasInputCol, HasPredictionCol,
                           HasMean, HasStandardDeviation, HasCenteredThreshold,
                           DefaultParamsReadable, DefaultParamsWritable):

    @keyword_only
    def __init__(self, inputCol=None, predictionCol=None,
                 mean=None, standardDeviation=None, centeredThreshold=None):
        super(NormalDeviationModel, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, predictionCol=None,
                  mean=None, standardDeviation=None, centeredThreshold=None):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _transform(self, dataset):
        x = self.getInputCol()
        y = self.getPredictionCol()
        threshold = self.getCenteredThreshold()
        mu = self.getMean()
        sigma = self.getStddev()

        return dataset.withColumn(y, (dataset[x] - mu) > threshold * sigma)
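Because the classes follow the Params conventions, the estimator also composes with the tuning utilities mentioned earlier. Here is a minimal sketch building a parameter grid over centeredThreshold (no evaluator wired up):

from pyspark.ml.tuning import ParamGridBuilder

nd = NormalDeviation().setInputCol("x")
grid = ParamGridBuilder() \
    .addGrid(nd.centeredThreshold, [1.0, 2.0, 3.0]) \
    .build()
# grid is a list of three param maps, one per threshold value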
Finally, it can be used like this:
df = sc.parallelize([(1, 2.0), (2, 3.0), (3, 0.0), (4, 99.0)]).toDF(["id", "x"])

normal_deviation = NormalDeviation().setInputCol("x").setCenteredThreshold(1.0)
model = Pipeline(stages=[normal_deviation]).fit(df)

model.transform(df).show()
## +---+----+----------+
## | id|   x|prediction|
## +---+----+----------+
## |  1| 2.0|     false|
## |  2| 3.0|     false|
## |  3| 0.0|     false|
## |  4|99.0|      true|
## +---+----+----------+
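And since both classes mix in DefaultParamsReadable and DefaultParamsWritable, the fitted pipeline can be saved and reloaded as well (PySpark >= 2.3.0; the path below is just an example):

from pyspark.ml import PipelineModel

model.write().overwrite().save("/tmp/normal_deviation_model")
# Loading requires NormalDeviationModel to be importable in the session:
reloaded = PipelineModel.load("/tmp/normal_deviation_model")
reloaded.transform(df).show()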