一直卡住的一个点：一个时间序列窗口的评价标准(标签)，到底是下一个时间序列窗口的涨跌趋势之和，还是下一行(即下一个时间单位内)的涨跌行情？用 print 打印出来之后终于弄明白了：标签取的是窗口之后紧邻的那一行的涨跌。
附上代码：
package com.wtx.job014
import org.apache.spark.sql.SparkSession
import org.apache.log4j.Logger
import org.apache.log4j.Level
import java.io.{ BufferedWriter, File, FileWriter }
import org.apache.spark.sql.types.{ DoubleType, IntegerType, StructField, StructType }
import org.apache.spark.sql.{ DataFrame, Row, SparkSession }
import scala.collection.mutable.ListBuffer
import java.util.Date
import java.text.SimpleDateFormat
object demo2 {
  // Silence Spark/Akka logging so the job's own println output stays readable.
  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  // Triple-quoted so the Windows backslashes are taken literally: in an ordinary string
  // literal, sequences such as `\U` and `\8` are invalid escapes and fail to compile.
  val priceDataFileName: String = """C:\Users\86183\Desktop\scala_machine_leraning_projects\ScalaMachineLearningData\比特币高频价格预测\bitstampUSD_1-min_data_2012-01-01_to_2018-03-27.csv"""
  val outputDataFilePath: String = "output/scala_test_x.csv"
  val outputLabelFilePath: String = "output/scala_test_y.csv"

  /**
   * Entry point: loads the price CSV, adds a `Delta` column (Close - Open) and a binary
   * `label` column (1 when Delta > 0, else 0), then writes rolling-window features and
   * labels via [[rollingWindow]].
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("bitcoin preprocessing").getOrCreate()
    try {
      import org.apache.spark.sql.functions._
      import spark.implicits._

      // No inferSchema option: every column is read as a string and cast implicitly below.
      val data = spark.read.format("com.databricks.spark.csv").option("header", "true").load(priceDataFileName)
      data.show()

      // Formats the epoch-seconds Timestamp as an Int of the form yyyyMMddHH.
      val simpleDateFormat = new SimpleDateFormat("yyyyMMddHH")
      // x is in seconds; multiply by a Long because x * 1000 would overflow Int (max 2147483647).
      spark.udf.register("changeDate", (x: Int) => simpleDateFormat.format(new Date(x * 1000L)).toInt)
      data.createOrReplaceTempView("data")
      val data1 = spark.sql("select changeDate(Timestamp) as timestamp,* from data")
      println((data.count(), data.columns.size))
      data1.show() // preview of the formatted dates only; data1 is not used further

      // Delta = close price minus open price for each row.
      val dataWithDelta = data.withColumn("Delta", data("Close") - data("Open"))
      // label = 1 when the price rose within the row (Delta > 0), otherwise 0.
      val dataWithLabels = dataWithDelta.withColumn("label", when($"Delta" > 0, 1).otherwise(0))
      rollingWindow(dataWithLabels, 22, outputDataFilePath, outputLabelFilePath)
    } finally {
      spark.stop()
    }
  }

  /** Number of leading (oldest) rows skipped before windowing starts. */
  val dropFirstCount: Int = 612000

  /**
   * Writes rolling-window samples to two line-oriented files.
   *
   * The first [[dropFirstCount]] rows are skipped. Then, for every start index `i`:
   *  - the `Delta` values of rows `i until i + window` are written comma-separated as
   *    one line of `xFileName`;
   *  - the `label` of the single row directly after the window (`i + window`) is written
   *    as the matching line of `yFileName`.
   *
   * @param data      DataFrame with a Double `Delta` column and an Int `label` column; assumed
   *                  to be in time order (the order of the underlying RDD partitions)
   * @param window    number of consecutive rows per feature line; must be positive
   * @param xFileName output path for feature lines
   * @param yFileName output path for label lines
   * @param verbose   when true (the default) prints every window and its label row to stdout
   * @throws IllegalArgumentException if `window` is not positive
   */
  def rollingWindow(data: DataFrame, window: Int, xFileName: String, yFileName: String, verbose: Boolean = true): Unit = {
    require(window > 0, s"window must be positive, got $window")
    // Local copy so the Spark closure captures a plain Long rather than the enclosing object.
    val skip = dropFirstCount.toLong
    // Drop the skipped rows on the executors before collecting, instead of pulling the whole
    // dataset into the driver first. zipWithIndex numbers rows in partition order and collect
    // preserves that order, so these are exactly the rows `collect().drop(dropFirstCount)` kept.
    val dataStratified = data.rdd.zipWithIndex().filter(_._2 >= skip).collect()

    val xWriter = new BufferedWriter(new FileWriter(new File(xFileName)))
    try {
      val yWriter = new BufferedWriter(new FileWriter(new File(yFileName)))
      try {
        // The label row at i + window must exist, hence the exclusive upper bound.
        for (i <- 0 until dataStratified.length - window) {
          val windowRows = dataStratified.slice(i, i + window)
          val labelRow = dataStratified(i + window)
          val stringToWrite = windowRows.map(_._1.getAs[Double]("Delta")).mkString(",")
          val y = labelRow._1.getAs[Int]("label")
          if (verbose) {
            windowRows.foreach(println)
            println("+++++++++第" + i + "次" + "x: " + stringToWrite)
            println("第" + i + "次" + "y: " + labelRow + " -----------------" + "\n")
          }
          xWriter.write(stringToWrite + "\n")
          yWriter.write(y + "\n")
        }
      } finally {
        yWriter.close() // close() flushes the buffer
      }
    } finally {
      xWriter.close()
    }
  }
}
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)