Currently I'm working with a specific version of Apache Spark (3.1.1) that I cannot upgrade. Because of that I can't use a recent Apache Sedona, and version 1.3.1 is too slow. My problem is the following code, which works as standalone pure Python but in cluster mode returns null for all rows.
I suspect the broadcast variables could be the source of the problem, but I can't find it.
Any suggestions?
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType
from pyspark import SparkContext, Broadcast
from shapely import wkt
from shapely.geometry import Point
from rtree import index
spark = SparkSession.builder \
    .appName("GeoUDFExample") \
    .master("local[*]") \
    .getOrCreate()
sc = spark.sparkContext
# Collect all geometries to the driver and build an R-tree over their bounds
geometries_df = spark.read.parquet("db_geometrias.geoparquet")
geometries = geometries_df.select("id", "geometry").collect()
rtree_idx = index.Index()
geom_dict = {}
for row in geometries:
    geom_id = row["id"]
    geom_wkt = row["geometry"]
    geom_obj = wkt.loads(geom_wkt)
    rtree_idx.insert(geom_id, geom_obj.bounds)
    geom_dict[geom_id] = geom_wkt
# Ship the index and the WKT lookup to the executors
rtree_broadcast: Broadcast = sc.broadcast(rtree_idx)
geom_dict_broadcast: Broadcast = sc.broadcast(geom_dict)
def point_to_geom_id(lat, lon):
    # Get R-tree candidates for the point, then confirm exact containment
    pt = Point(lon, lat)
    candidates = list(rtree_broadcast.value.intersection((lon, lat, lon, lat)))
    for int_id in candidates:
        geom_wkt = geom_dict_broadcast.value[int_id]
        geom = wkt.loads(geom_wkt)
        if geom.contains(pt):
            return int_id
    return None
geo_udf = udf(point_to_geom_id, IntegerType())
points_df = spark.createDataFrame([
    (1, -78.5, -0.2),
    (2, -78.3, -0.1),
    (3, -78.6, -0.25)
], ["point_id", "lon", "lat"])
result_df = points_df.withColumn("geom_id", geo_udf("lat", "lon"))
result_df.show()
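As far as I understand, sc.broadcast pickles the value before shipping it to the executors, so if the broadcast is the culprit I would expect a plain pickle round-trip of the index to reproduce the problem. A minimal check, using the first sample point as the query box:

import pickle

# Round-trip the index the same way a broadcast would (via pickle),
# then run the same query on the original and on the restored copy.
# If this raises, or the second list comes back empty, the index
# does not survive serialization to the executors.
restored = pickle.loads(pickle.dumps(rtree_idx))
print(list(rtree_idx.intersection((-78.5, -0.2, -78.5, -0.2))))
print(list(restored.intersection((-78.5, -0.2, -78.5, -0.2))))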
I tried to debug the problem with a fixed return value from the UDF, but got the same result. I also changed the broadcast variables to plain globals, but that used too much memory.
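For reference, the global-variable variant I tried looks roughly like this (a reconstructed sketch: the UDF closes over rtree_idx and geom_dict from above instead of the broadcasts):

# Reconstructed sketch of the globals attempt: the closure captures
# rtree_idx and geom_dict directly, so Spark serializes a copy of both
# along with the UDF, which is what drove memory usage up.
def point_to_geom_id_global(lat, lon):
    pt = Point(lon, lat)
    for int_id in rtree_idx.intersection((lon, lat, lon, lat)):
        geom = wkt.loads(geom_dict[int_id])
        if geom.contains(pt):
            return int_id
    return None

geo_udf_global = udf(point_to_geom_id_global, IntegerType())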