Wednesday, December 27, 2017

Spark DataFrames notes


As I wean myself off inefficient (but easy), maintenance-mode RDDs, these are some notes I made on Datasets and DataFrames.

Datasets for all

In this example, I was looking at bank account transactions for which I wanted to build histograms of nominal values. I created my case class like so:

case class Record(acc: String, histo: String, name: String, value: Double, number: Int)

and wanted to create histogram buckets of usage data. I won't bore you with the details but suffice it to say I had a function that looked like:

  type Bucket[T, U] = (T, U, Int)
  type Buckets[K] = Seq[Bucket[K, Double]]
  def processAll(rows: Iterator[Row]): Map[String, Buckets[String]] = ...

Now I can run it over the pipe-delimited raw data:

  val byAccountKey = rawDataDF.groupByKey(_.getString(IBAN_INDEX))
  val accTimeHistoCase = byAccountKey.flatMapGroups { case (acc, rows) =>
    processAll(rows).flatMap(bs => bs._2.map(x => Record(acc, bs._1, x._1, x._2, x._3)))
  }
  val ds = accTimeHistoCase.as[Record]

only for it to give me runtime exceptions because Spark doesn't know how to serialize my Records. Oops. 

Not to worry. Let's just add a row encoder.

import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.encoders.RowEncoder

val schema = StructType(
  StructField("acc", StringType, nullable = false) ::
  StructField("histo", StringType, nullable = false) ::
  StructField("name", StringType, nullable = false) ::
  StructField("value", DoubleType, nullable = false) ::
  StructField("number", IntegerType, nullable = false) ::
  Nil)
implicit val recordEncoder = RowEncoder(schema)

and we're good to go.

User Defined Functions

... except my data is in a Continental format where commas are used as decimal separators. This causes problems when parsing Doubles. So, we can add a UDF like so:

import org.apache.spark.sql.functions.udf

def commasAsPeriods(x: String): Double = x.replace(",", ".").toDouble
val numFormatter = udf(commasAsPeriods _)
val df = rawDataDF.withColumn("CONTINENTAL_NUMBER", numFormatter('CONTINENTAL_NUMBER))

and now we can parse the whole file.

Built-in Functions

There is a plethora of functions for Spark Datasets. For instance, if you wanted to do a sort of flatMap, turning a Dataset[(K, Seq[V])] into a Dataset[(K, V)] with one row per element of the Seq, you would use explode. (But be careful that this doesn't consume too much memory. I've seen OutOfMemoryErrors from using this function.)
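As a minimal sketch, assuming a hypothetical Dataset called grouped with a histo column and a names column holding a Seq[String] (and spark.implicits._ in scope, as in the shell):

import org.apache.spark.sql.functions.explode

// Each input row (histo, Seq(name1, name2, ...)) becomes one output row per name.
val flattened = grouped.select($"histo", explode($"names").as("name"))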

Conversely, collect_set does sort of the opposite: it aggregates a number of rows into one. For instance, if we wanted to know all the bucket names for a single histogram type, we could write:

scala> ds.groupBy($"histo").agg(collect_set($"name")).show()
+----------------+--------------------+
|           histo|   collect_set(name)|
+----------------+--------------------+
|receivingAccount|[ATXXXXXXXXXXXXXX...|
|     countryCode|[NA, XK, VN, BE, ...|
|       dayOfWeek|[FRIDAY, MONDAY, ...|
|         dayTime|[EVENING, FRIDAY,...|
+----------------+--------------------+

These are fairly simple functions but we can make them arbitrarily complex (see this SO answer for a recipe).
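As a sketch of where that can go (this is not the recipe from that answer; the DistinctNames object below is my own hypothetical example, assuming the Record case class above and Spark 2.x), a custom Aggregator can hold arbitrary logic and still plug into agg():

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Collects the distinct bucket names seen in each group and returns how many there were.
object DistinctNames extends Aggregator[Record, Set[String], Int] {
  def zero: Set[String] = Set.empty[String]
  def reduce(buffer: Set[String], record: Record): Set[String] = buffer + record.name
  def merge(b1: Set[String], b2: Set[String]): Set[String] = b1 union b2
  def finish(buffer: Set[String]): Int = buffer.size
  def bufferEncoder: Encoder[Set[String]] = Encoders.kryo[Set[String]]
  def outputEncoder: Encoder[Int] = Encoders.scalaInt
}

// Used with the typed API:
// ds.groupByKey(_.histo).agg(DistinctNames.toColumn).show()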

Aggregation

Now each row of my Dataset, ds, is bucket data for a given histogram and a given account number.

We can see how much bucket data there is for each histogram.

scala> ds.groupByKey(_.histo).agg(count("name").as[Long]).show()
+----------------+-----------+
|           value|count(name)|
+----------------+-----------+
|receivingAccount|    1530942|
|     countryCode|     465430|
|       dayOfWeek|     866661|
|         dayTime|     809537|
+----------------+-----------+

Since we're dealing with nominal data, a lot of the bucket names are going to be duplicated: there are many data points for the same histogram and bucket (but obviously with different account numbers). For example, the countryCode histogram will have names like CH, GB, FR etc. Let's see this:

scala> ds.groupBy('histo).agg(countDistinct("name")).show()

+----------------+--------------------+
|           histo|count(DISTINCT name)|
+----------------+--------------------+
|receivingAccount|              732020|
|     countryCode|                 152|
|       dayOfWeek|                   8|
|         dayTime|                  10|
+----------------+--------------------+

Evidently, there are 152 different countries to which our customers send money.

We can, of course, aggregate by multiple columns, thus:

scala> ds.groupBy('histo, 'name).agg(count("name")).show()
+----------------+--------------------+-----------+
|           histo|                name|count(name)|
+----------------+--------------------+-----------+
|     countryCode|                  CH|        956|
.
.

groupBy vs. groupByKey

The method groupBy produces a RelationalGroupedDataset and is "used for untyped aggregates using DataFrames. Grouping is described using column expressions or column names."

Whereas groupByKey produces a KeyValueGroupedDataset that is "used for typed aggregates using Datasets with records grouped by a key-defining discriminator function" (Jacek Laskowski).

There's a subtle difference in the API between using groupBy and groupByKey. The agg() method on the result of the former takes a Column but the agg() of the latter takes a TypedColumn.

"In Datasets, typed operations tend to act on TypedColumn instead." Fortunately, "to create a TypedColumn, all we have to do is call as[..]" (GitHub)

So, the following:

scala> ds.groupBy('histo).agg(countDistinct('name)).show()

and

scala> ds.groupByKey(_.histo).agg(countDistinct('name).as[Long]).show()

are equivalent.

Similarly, the aggregation functions used with groupByKey must be the typed variants. For instance:

scala> import org.apache.spark.sql.expressions.scalalang._
scala> ds.groupByKey(_.histo).agg(typed.avg(_.value)).show()

and

scala> ds.groupBy('histo).agg(avg('value)).show()

are also equivalent.

Differences between Datasets

I'd not heard of anti-joins before but they're a good way to find the elements in one Dataset that are not in another (see the Spark mailing list here). The different join types (with examples) can be found here on SO, where Spark's "left_anti" is the interesting one.
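
A minimal sketch, assuming two hypothetical Datasets of Record, dsThisWeek and dsLastWeek:

// Keep only the records whose acc appears in dsThisWeek but not in dsLastWeek.
val onlyThisWeek = dsThisWeek.join(dsLastWeek, Seq("acc"), "left_anti")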

