pyspark использует фрейм данных внутри udf


У меня есть два фрейма данных df1

+---+---+----------+
|  n|val| distances|
+---+---+----------+
|  1|  1|0.27308652|
|  2|  1|0.24969208|
|  3|  1|0.21314497|
+---+---+----------+

И df2

+---+---+----------+
| x1| x2|         w|
+---+---+----------+
|  1|  2|0.03103427|
|  1|  4|0.19012526|
|  1| 10|0.26805446|
|  1|  8|0.26825935|
+---+---+----------+

Я хочу добавить новый столбец в df1 под названием gamma, который будет содержать сумму значения w из df2, Когда df1.n == df2.x1 OR df1.n == df2.x2

Я пытался использовать udf, но, по-видимому, выбор из разных фреймов данных не будет работать, потому что значения должны быть определены перед вычислением

gamma_udf = udf(lambda n: float(df2.filter("x1 = %d OR x2 = %d"%(n,n)).groupBy().sum('w').rdd.map(lambda x: x).collect()[0]), FloatType())
df1.withColumn('gamma1', gamma_udf('n'))
Есть ли способ сделать это с помощью join или groupby без использования циклов?
1 2

1 ответ:

Вы не можете ссылаться на фрейм данных внутри udf. Как вы уже упоминали, эта проблема лучше всего решается с помощью join.

IIUC, вы ищете что-то вроде:

from pyspark.sql import Window
import pyspark.sql.functions as F

df1.alias("L").join(df2.alias("R"), (df1.n == df2.x1) | (df1.n == df2.x2), how="left")\
    .select("L.*", F.sum("w").over(Window.partitionBy("n")).alias("gamma"))\
    .distinct()\
    .show()
#+---+---+----------+----------+
#|  n|val| distances|     gamma|
#+---+---+----------+----------+
#|  1|  1|0.27308652|0.75747334|
#|  3|  1|0.21314497|      null|
#|  2|  1|0.24969208|0.03103427|
#+---+---+----------+----------+

Или, если вам удобнее использовать синтаксис pyspark-sql, Вы можете зарегистрировать временные таблицы и сделать:

df1.registerTempTable("df1")
df2.registerTempTable("df2")

sqlCtx.sql(
    "SELECT DISTINCT L.*, SUM(R.w) OVER (PARTITION BY L.n) AS gamma "
    "FROM df1 L LEFT JOIN df2 R ON L.n = R.x1 OR L.n = R.x2"
).show()
#+---+---+----------+----------+
#|  n|val| distances|     gamma|
#+---+---+----------+----------+
#|  1|  1|0.27308652|0.75747334|
#|  3|  1|0.21314497|      null|
#|  2|  1|0.24969208|0.03103427|
#+---+---+----------+----------+

Объяснение

В обоих случаях мы делаем левое соединение из df1 в df2. Это сохранит все строки в df1 независимо от того, есть ли совпадение.

В предложение join-это условие, которое вы указали в своем вопросе. Таким образом, все строки в df2, где либо x1, либо x2 равно n, будут соединены.

Далее выбираем все строки из левых таблиц плюс группируем по (разбиваем по) n и суммируем значения w. Это позволит получить сумму по всем строкам, которые соответствовали условию соединения, для каждого значения n.

Наконец, мы возвращаем только отдельные строки, чтобы исключить дубликаты.