
I am trying to add a column to my Spark DataFrame that flags whether each row's value in a given column appears in a separate DataFrame.

This is my main Spark DataFrame (df_main):

+--------+
|main    |
+--------+
|28asA017|
|03G12331|
|1567L044|
|02TGasd8|
|1asd3436|
|A1234567|
|B1234567|
+--------+

This is my reference DataFrame (df_ref). There are hundreds of rows in it, so I obviously can't hard-code the values as in this solution or this one:

+--------+
|mask_vl |
+--------+
|A1234567|
|B1234567|
...
+--------+

Normally, what I'd do with a pandas DataFrame is this:

df_main['is_inref'] = np.where(df_main['main'].isin(df_ref.mask_vl.values), "YES", "NO")

So that I would get this

+--------+--------+
|main    |is_inref|
+--------+--------+
|28asA017|NO      |
|03G12331|NO      |
|1567L044|NO      |
|02TGasd8|NO      |
|1asd3436|NO      |
|A1234567|YES     |
|B1234567|YES     |
+--------+--------+
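For completeness, here is that pandas approach run end-to-end on the sample data (a sketch; the frame and column names are taken from the question):

import numpy as np
import pandas as pd

# Rebuild the sample data from the question
df_main = pd.DataFrame({"main": ["28asA017", "03G12331", "1567L044",
                                 "02TGasd8", "1asd3436", "A1234567", "B1234567"]})
df_ref = pd.DataFrame({"mask_vl": ["A1234567", "B1234567"]})

# Flag each row of df_main whose value appears in df_ref
df_main["is_inref"] = np.where(df_main["main"].isin(df_ref["mask_vl"].values), "YES", "NO")
print(df_main)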

I have tried the following code, but I don't understand the error it raises.

df_main = df_main.withColumn('is_inref', "YES" if F.col('main').isin(df_ref) else "NO")
df_main.show(20, False)


2 Answers


You are close. I think the additional step you need is to explicitly create a list containing the values from df_ref.

Please see below an illustration:

# Create your DataFrames
df_main = spark.createDataFrame(["28asA017", "03G12331", "1567L044", "02TGasd8", "1asd3436", "A1234567", "B1234567"], "string").toDF("main")
df_ref = spark.createDataFrame(["A1234567", "B1234567"], "string").toDF("mask_vl")

Then, you can create a list and use isin, almost as you have it:

# Imports
from pyspark.sql.functions import col, when

# Create a list with the values of your reference DF
mask_vl_list = df_ref.select("mask_vl").rdd.flatMap(lambda x: x).collect()

# Use isin to check whether the values in your column exist in the list
df_main = df_main.withColumn('is_inref', when(col('main').isin(mask_vl_list), 'YES').otherwise('NO'))

This will give you:

>>> df_main.show()

+--------+--------+
|    main|is_inref|
+--------+--------+
|28asA017|      NO|
|03G12331|      NO|
|1567L044|      NO|
|02TGasd8|      NO|
|1asd3436|      NO|
|A1234567|     YES|
|B1234567|     YES|
+--------+--------+

3 Comments

Thanks for the quick response @sophocles. Am I correct to understand that the first command (the rdd.flatMap(...).collect()) basically collects the DataFrame's values into a list on the driver node? If so, won't I run into an out-of-memory exception if the reference becomes huge?
Welcome. Yes you are right. I don't think that you will run into memory exception problems as this is an efficient approach. You can check out more information here about performance benchmarking to convert a column to a list.
I guess that collect() is not the best solution. If the mask_vl DataFrame grows, it will be a problem.

If you want to avoid collect(), I advise you to do the following: rename the reference column to match, tag every reference row, left-join, and then check for nulls.

from pyspark.sql.functions import col, lit, when

df_ref = (df_ref
          .withColumnRenamed("mask_vl", "main")
          .withColumn("isPresent", lit("YES")))

df_main = (df_main
           .join(df_ref, on="main", how="left_outer")
           .withColumn("is_inref",
                       when(col("isPresent").isNull(), lit("NO")).otherwise(lit("YES")))
           .drop("isPresent"))
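The join-plus-null-check pattern above can be sanity-checked outside Spark. A minimal pandas equivalent (a sketch, assuming the sample data from the question) is:

import numpy as np
import pandas as pd

df_main = pd.DataFrame({"main": ["28asA017", "03G12331", "1567L044",
                                 "02TGasd8", "1asd3436", "A1234567", "B1234567"]})
df_ref = pd.DataFrame({"mask_vl": ["A1234567", "B1234567"]})

# Rename the reference column to match, tag its rows, then left-join:
# rows with no match get NaN in "isPresent", just like the null check in Spark.
tagged = df_ref.rename(columns={"mask_vl": "main"}).assign(isPresent="YES")
joined = df_main.merge(tagged, on="main", how="left")
joined["is_inref"] = np.where(joined["isPresent"].isna(), "NO", "YES")
print(joined[["main", "is_inref"]])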

