微信号:SparkDaily

介绍:每日播报Spark相关技术及资讯,我们坚信Spark才是未来的通用大数据处理框架.

Spark 2.0 机器学习实践:Iris 数据分类

2017-01-10 17:13 Spark技术日报

原文:http://www.flyml.net/2017/01/09/spark-2-0-ml-practice-iris-test/


之前尝试使用Spark MLlib 做机器学习,发现不是非常方便,也可能是在使用习惯上面不太适应(相对 python sklearn).今天尝试使用Spark MLlib 针对Iris数据做一次实践,之后会尝试写一个包装类,将这些步骤简化。


0. 数据准备:


原始的数据以及相应的说明可以到[这里] 下载。 我在这基础之上,增加了header信息。 下载:https://pan.baidu.com/s/1c2d0hpA


如果是可以直接从NFS或者HDFS之类的文件服务里面读csv,会比较方便, 参考下面的python代码:


from pyspark.sql import SQLContext

sqlContext = SQLContext(sc)

df = sqlContext.read.format('com.databricks.spark.csv')

    .options(header='true', inferschema='true')

    .load('iris.csv')

# Displays the content of the DataFrame to stdout

df.show()


在我的环境之中,如果你跟我一样只能从本地读取,就比较麻烦了。 可以参考下面的Java代码:


// 首先读iris 数据

// 因为是从本地读取Sample数据,所以比较麻烦一些~

List<String> lines = FileUtils.readLines(new File("E:\\DataSet\\iris_data.txt"), "UTF-8");

List<Row> data = Lists.newArrayList();

String[] headers = lines.get(0).split(",");

for(String line : lines.subList(1, lines.size())) {

// 前面几个都是double

String[] cells = line.split(",");

Object[] values = new Object[cells.length];

for(int i = 0; i < cells.length - 1; i++) {

values[i] = Double.parseDouble(cells[i]);

}

values[cells.length - 1] = cells[cells.length - 1];

data.add(RowFactory.create(values));

}


// 创建Dataset

StructField[] fields = new StructField[headers.length];

for(int i = 0; i < headers.length - 1; i++) {

fields[i] = new StructField(headers[i], DataTypes.DoubleType, false, Metadata.empty());

}

fields[headers.length - 1] = new StructField(headers[headers.length - 1], DataTypes.StringType, false, Metadata.empty());

StructType schema = new StructType(fields);

Dataset<Row> df = ss.createDataFrame(data, schema);

df.show();


df.show() 的结果如下所示:


+------------+-----------+------------+-----------+-----------+

|sepal_length|sepal_width|petal_length|petal_width|    classes|

+------------+-----------+------------+-----------+-----------+

|         5.1|        3.5|         1.4|        0.2|Iris-setosa|

|         4.9|        3.0|         1.4|        0.2|Iris-setosa|

|         4.7|        3.2|         1.3|        0.2|Iris-setosa|

|         4.6|        3.1|         1.5|        0.2|Iris-setosa|

|         5.0|        3.6|         1.4|        0.2|Iris-setosa|

|         5.4|        3.9|         1.7|        0.4|Iris-setosa|

|         4.6|        3.4|         1.4|        0.3|Iris-setosa|

|         5.0|        3.4|         1.5|        0.2|Iris-setosa|

|         4.4|        2.9|         1.4|        0.2|Iris-setosa|

|         4.9|        3.1|         1.5|        0.1|Iris-setosa|

|         5.4|        3.7|         1.5|        0.2|Iris-setosa|

|         4.8|        3.4|         1.6|        0.2|Iris-setosa|

|         4.8|        3.0|         1.4|        0.1|Iris-setosa|

|         4.3|        3.0|         1.1|        0.1|Iris-setosa|

|         5.8|        4.0|         1.2|        0.2|Iris-setosa|

|         5.7|        4.4|         1.5|        0.4|Iris-setosa|

|         5.4|        3.9|         1.3|        0.4|Iris-setosa|

|         5.1|        3.5|         1.4|        0.3|Iris-setosa|

|         5.7|        3.8|         1.7|        0.3|Iris-setosa|

|         5.1|        3.8|         1.5|        0.3|Iris-setosa|

+------------+-----------+------------+-----------+-----------+

only showing top 20 rows


1. 使用StringIndexer将字符型的label变成index


// 使用StringIndexer将字符型的label变成index

StringIndexer indexer = new StringIndexer()

  .setInputCol("classes")

  .setOutputCol("classesIndex");

Dataset<Row> indexed = indexer.fit(df).transform(df);

indexed.show();


这就是Spark第一个不太方便的地方:不能直接处理String类型的label


处理完成之后,show的结果如下:


+------------+-----------+------------+-----------+-----------+------------+

|sepal_length|sepal_width|petal_length|petal_width|    classes|classesIndex|

+------------+-----------+------------+-----------+-----------+------------+

|         5.1|        3.5|         1.4|        0.2|Iris-setosa|         2.0|

|         4.9|        3.0|         1.4|        0.2|Iris-setosa|         2.0|

|         4.7|        3.2|         1.3|        0.2|Iris-setosa|         2.0|

|         4.6|        3.1|         1.5|        0.2|Iris-setosa|         2.0|


如果需要将Index变回来,那么需要用到IndexToString:


IndexToString converter = new IndexToString()

  .setInputCol("classesIndex")

  .setOutputCol("originalClasses");

Dataset<Row> converted = converter.transform(indexed);


converted.show();


2. 数据模型的创建与验证


在Spark的机器学习之中,有一个很容易让初学者混淆的问题:ml跟mllib有什么区别?


简单的说:


  • spark.mllib中的算法接口是基于RDDs的

  • spark.ml中的算法接口是基于DataFrames / Dataset 的


但是根据作者自己的经验,如果你处理的是CSV格式的数据,除非你现行转换成Libsvm的格式,否则后期处理非常非常的麻烦。具体的处理方式将在后续的文章之中尝试,敬请关注。


尝试使用RDD的方式(mllib)进行分类


使用RDD-based API, 核心就是整合出一个LabeledPoint.


// 将Row --> LabeledPoint

JavaRDD<LabeledPoint> rowRDD = indexed.toJavaRDD().map(new Function<Row, LabeledPoint>() {

@Override

public LabeledPoint call(Row row) throws Exception {

double[] features = new double[4];

for(int i = 0; i < 4; i++) {

features[i] = row.getDouble(i);

}

LabeledPoint point = new LabeledPoint(row.getDouble(6), Vectors.dense(features));

return point;

}

});


当你整合出这个LabeledPoint RDD之后,就直接copy官网代码即可。


就不在这里贴代码了。 比如,如果你采用的是RandomForest, 可以请参考:https://spark.apache.org/docs/2.0.2/mllib-ensembles.html#random-forests


比如在喂给RandomForest的时候,我们需要设置好几个参数:


  • numClasses

    需要提前设置好有那几个类

  • numTrees

    有几棵树

  • categoricalFeaturesInfo

    每一个feature有几个类别?

  • featureSubsetStrategy

    auto: 默认参数。让算法自己决定,每颗树使用几条数据

  • impurity / maxDepth / maxBins / seed

 

3.  检查预测结果


官网使用的是比较简单粗暴的比较方式:


Double testErr =

  1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {

    @Override

    public Boolean call(Tuple2<Double, Double> pl) {

      return !pl._1().equals(pl._2());

    }

  }).count() / testData.count();

System.out.println("Test Error: " + testErr);

System.out.println("Learned classification forest model:\n" + model.toDebugString());


这样确实能看到结果,但是如果想查看比如classification_report等等,Spark自带的类能提供一些比较方便的东西。


在之前官网的基础之上,需要修改predictionAndLabel的数据类型:


JavaRDD<Tuple2<Object, Object>> predictionAndLabels = testData.map(

  new Function<LabeledPoint, Tuple2<Object, Object>>() {

    public Tuple2<Object, Object> call(LabeledPoint p) {

      Double prediction = model.predict(p.features());

      return new Tuple2<Object, Object>(prediction, p.label());

    }

  }

);


有两个要注意的地方:


  • 类型需要是Object, 之前的Double不行

  • 从之前的JavaPairRDD 变成 JavaRDD<Tuple2>


做好这个准备之后,我们就可以调用Metrics相关的工具类了:


// 多分类:

MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());


// 二分类

BinaryClassificationMetrics metrics =

  new BinaryClassificationMetrics(predictionAndLabels.rdd());


// 多标签分类:(相对来说少遇到)

MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd());


一些简单实用的Sample:


System.out.println("准确率:" + metrics.accuracy());

// 准确率:0.9583333333333334



Matrix confusion = metrics.confusionMatrix();

System.out.println("混淆矩阵: \n" + confusion);

// 混淆矩阵: 

// 15.0  1.0   0.0   

// 0.0   18.0  0.0   

// 0.0   1.0   13.0


至此,一个基本的流程算是走通了,但是我们可以看到,在这整个过程之中有一些很不方便的事情:


  1. 读取本地的CSV非常不方便

  2. 不支持String类型的label,需要使用StringIndexer。即使读取原始数据是数值类型,也需要使用StringIndexer, 因为除非使用spark-csv并且设置了inferSchema=true, 否则也自动被认为是String类型的值

  3. 在使用RandomForest的时候,好几个参数需要设置。我感觉有的是应该可以自动设置的。比如:numClasses 、 categoricalFeaturesInfo

  4. 在最后检查结果的时候,比较麻烦:

  • 需要自己选择是multi-class or binary-class.

  • 混淆矩阵缺少label

  • 缺少类似sklearn.classification_report 那种简单明了的report

 

后面会陆续针对这些问题,做一些wrapper。







【长按识别立即关注】 

 品读之后,愿有所获。

 
Spark技术日报 更多文章 基于Python的Spark Streaming+Kafka编程实践 Storm介绍及与Spark Streaming对比 Spark Standalone架构设计要点分析 Spark 2.0 机器学习实践:Iris 数据分类 Spark 2.0 机器学习实践:Iris 数据分类
猜您喜欢 阿里美女员工精彩总结:“从0到1”里没有告诉你的事 Java性能优化指南 ,及唯品会的实战(修订版) 为什么说编程是有史以来最好的工作 Codis作者黄东旭细说分布式Redis架构设计和踩过的那些坑们 FEX 技术周刊 - 2016\/08\/22