You can use aggregate and map_concat:
import org.apache.spark.sql.functions.{expr, collect_list}
import spark.implicits._ // needed for toDF and $ outside the spark-shell

val df = Seq(
  (1, Map("k1" -> "v1", "k2" -> "v3")),
  (1, Map("k3" -> "v3")),
  (2, Map("k4" -> "v4")),
  (2, Map("k6" -> "v6", "k5" -> "v5"))
).toDF("id", "data")

// Fold the collected array of maps into a single map per id
val mergeExpr = expr("aggregate(data, map(), (acc, i) -> map_concat(acc, i))")

df.groupBy("id").agg(collect_list("data").as("data"))
  .select($"id", mergeExpr.as("merged_data"))
  .show(false)
// +---+------------------------------+
// |id |merged_data |
// +---+------------------------------+
// |1 |[k1 -> v1, k2 -> v3, k3 -> v3]|
// |2 |[k4 -> v4, k6 -> v6, k5 -> v5]|
// +---+------------------------------+
With map_concat we concatenate all the map items of the data column, using the aggregate built-in function to fold the concatenation over the elements of the collected list.
Attention: the current implementation of map_concat on Spark 2.4.5 allows identical keys to co-exist in the resulting map. This is most likely a bug, since it is not the expected behaviour according to the official documentation. Please be aware of that.
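Here is a minimal sketch of that caveat, assuming the Spark 2.4.x behaviour described above (the column names m1/m2 are just for illustration):
// Concatenating two maps that share the key "k1".
// On Spark 2.4.5 this may silently produce a map containing "k1" twice
// instead of deduplicating or failing.
val dup = Seq((Map("k1" -> "a"), Map("k1" -> "b"))).toDF("m1", "m2")
dup.select(expr("map_concat(m1, m2)").as("merged")).show(false)
// merged: [k1 -> a, k1 -> b]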
If you want to avoid such a case you can also go for a UDF:
import org.apache.spark.sql.functions.{collect_list, udf}
val mergeMapUDF = udf((data: Seq[Map[String, String]]) => data.reduce(_ ++ _))
df.groupBy("id").agg(collect_list("data").as("data"))
.select($"id", mergeMapUDF($"data").as("merged_data"))
.show(false)
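Note that Scala's Map ++ keeps the value from the right-hand map when a key appears in both, so the UDF resolves duplicate keys with a last-map-wins policy rather than letting them co-exist. A quick check in the plain Scala REPL:
// ++ deduplicates: the value from the later map wins for "k1"
val merged = Map("k1" -> "a") ++ Map("k1" -> "b", "k2" -> "c")
// merged == Map("k1" -> "b", "k2" -> "c")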
UPDATE (2022-08-27)
- In Spark 3.3.0 the above code doesn't work and the following exception is thrown:
AnalysisException: cannot resolve 'aggregate(`data`, map(), lambdafunction(map_concat(namedlambdavariable(), namedlambdavariable()), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))' due to data type mismatch: argument 3 requires map<null,null> type, however, 'lambdafunction(map_concat(namedlambdavariable(), namedlambdavariable()), namedlambdavariable(), namedlambdavariable())' is of map<string,string> type.;
Project [id#110, aggregate(data#119, map(), lambdafunction(map_concat(cast(lambda acc#122 as map<string,string>), lambda i#123), lambda acc#122, lambda i#123, false), lambdafunction(lambda id#124, lambda id#124, false)) AS aggregate(data, map(), lambdafunction(map_concat(namedlambdavariable(), namedlambdavariable()), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))#125]
+- Aggregate [id#110], [id#110, collect_list(data#111, 0, 0) AS data#119]
+- Project [_1#105 AS id#110, _2#106 AS data#111]
+- LocalRelation [_1#105, _2#106]
It seems that map() is initialised as map<null,null> when map<string,string> is expected.
To fix this, cast map() to map<string, string> explicitly with cast(map() as map<string, string>).
Here is the updated code:
val mergeExpr = expr("aggregate(data, cast(map() as map<string, string>), (acc, i) -> map_concat(acc, i))")
df.groupBy("id").agg(collect_list("data").as("data"))
.select($"id", mergeExpr)
.show(false)
- Regarding the identical-keys bug, this seems to be fixed in recent versions: if you try to concatenate maps with identical keys, an exception is thrown:
Caused by: RuntimeException: Duplicate map key k5 was found, please check the input data. If you want to remove the duplicated keys, you can set spark.sql.mapKeyDedupPolicy to LAST_WIN so that the key inserted at last takes precedence.
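If you do want duplicate keys to be deduplicated (last one wins) instead of getting the exception, you can follow the hint in the error message and switch spark.sql.mapKeyDedupPolicy at runtime; a minimal sketch:
// Let the key inserted last take precedence instead of failing
spark.conf.set("spark.sql.mapKeyDedupPolicy", "LAST_WIN")
// re-running the groupBy/aggregate snippet above now merges duplicate keys
// instead of throwing the RuntimeException shown above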