Split a list of overlapping intervals into non-overlapping subintervals in a pyspark dataframe and check if values are valid on overlapped intervals
Question
I have a pyspark dataframe that contains the columns start_time and end_time, which define an interval per row. It also contains a column is_duplicated, set to True if an interval is overlapped by at least one other interval, and to False if not.
There is also a column rate, and I want to know whether a sub-interval (which is overlapped, by definition) carries different rate values; if it does, I want to keep the record with the latest update, given by the column updated_at, as the ground truth.
As an intermediary step, I was thinking of creating a column is_validated set to:
- None when the sub-interval is not overlapped
- True when the sub-interval is overlapped by another one containing a different rate value and is the last updated
- False when the sub-interval is overlapped by another one containing a different rate value and is NOT the last updated
Note: the intermediary step is not mandatory; I provide it just to make the explanation clearer.
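For reference, a minimal sketch of that flagging step (not a full solution), assuming a dataframe sub_df, a hypothetical name for a frame that already holds one row per split sub-interval, like tmp_df below:

from pyspark.sql import functions as F, Window

# Rows sharing the same (start_time, end_time) are the candidate records
# for one sub-interval.
w = Window.partitionBy('start_time', 'end_time')
flagged = (sub_df  # hypothetical: one row per split sub-interval
           .withColumn('is_duplicated', F.count(F.lit(1)).over(w) > 1)
           .withColumn('is_validated',
                       # None when the sub-interval is not overlapped,
                       # otherwise True only for the latest updated record.
                       F.when(~F.col('is_duplicated'), F.lit(None).cast('boolean'))
                        .otherwise(F.col('updated_at') == F.max('updated_at').over(w))))

Timestamps stored as 'yyyy-MM-dd HH:mm:ss' strings compare correctly as plain strings, so the max needs no casting.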
Input:
# So this:
input_rows = [Row(start_time='2018-01-01 00:00:00', end_time='2018-01-04 00:00:00', rate=10, updated_at='2021-02-25 00:00:00'),  # OVERLAP: (1,4) and (2,3) and (3,5) and rate=10/20
              Row(start_time='2018-01-02 00:00:00', end_time='2018-01-03 00:00:00', rate=10, updated_at='2021-02-25 00:00:00'),  # OVERLAP: full overlap for (2,3) with (1,4)
              Row(start_time='2018-01-03 00:00:00', end_time='2018-01-05 00:00:00', rate=20, updated_at='2021-02-20 00:00:00'),  # OVERLAP: (3,5) and (1,4) and rate=10/20
              Row(start_time='2018-01-06 00:00:00', end_time='2018-01-07 00:00:00', rate=30, updated_at='2021-02-25 00:00:00'),  # NO OVERLAP: hole between (5,6)
              Row(start_time='2018-01-07 00:00:00', end_time='2018-01-08 00:00:00', rate=30, updated_at='2021-02-25 00:00:00')]  # NO OVERLAP
df = spark.createDataFrame(input_rows)
df.show()
>>>
+-------------------+-------------------+----+-------------------+
|         start_time|           end_time|rate|         updated_at|
+-------------------+-------------------+----+-------------------+
|2018-01-01 00:00:00|2018-01-04 00:00:00|  10|2021-02-25 00:00:00|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|2021-02-25 00:00:00|
|2018-01-03 00:00:00|2018-01-05 00:00:00|  20|2021-02-20 00:00:00|
|2018-01-06 00:00:00|2018-01-07 00:00:00|  30|2021-02-25 00:00:00|
|2018-01-07 00:00:00|2018-01-08 00:00:00|  30|2021-02-25 00:00:00|
+-------------------+-------------------+----+-------------------+
# Will become:
tmp_rows = [Row(start_time='2018-01-01 00:00:00', end_time='2018-01-02 00:00:00', rate=10, updated_at='2021-02-25 00:00:00', is_duplicated=False, is_validated=None),
            Row(start_time='2018-01-02 00:00:00', end_time='2018-01-03 00:00:00', rate=10, updated_at='2021-02-25 00:00:00', is_duplicated=True, is_validated=True),
            Row(start_time='2018-01-02 00:00:00', end_time='2018-01-03 00:00:00', rate=10, updated_at='2021-02-25 00:00:00', is_duplicated=True, is_validated=True),
            Row(start_time='2018-01-03 00:00:00', end_time='2018-01-04 00:00:00', rate=10, updated_at='2021-02-25 00:00:00', is_duplicated=True, is_validated=True),
            Row(start_time='2018-01-03 00:00:00', end_time='2018-01-04 00:00:00', rate=20, updated_at='2021-02-20 00:00:00', is_duplicated=True, is_validated=False),
            Row(start_time='2018-01-04 00:00:00', end_time='2018-01-05 00:00:00', rate=20, updated_at='2021-02-20 00:00:00', is_duplicated=False, is_validated=None),
            Row(start_time='2018-01-06 00:00:00', end_time='2018-01-07 00:00:00', rate=30, updated_at='2021-02-25 00:00:00', is_duplicated=False, is_validated=None),
            Row(start_time='2018-01-07 00:00:00', end_time='2018-01-08 00:00:00', rate=30, updated_at='2021-02-25 00:00:00', is_duplicated=False, is_validated=None)]
tmp_df = spark.createDataFrame(tmp_rows)
tmp_df.show()
>>>
+-------------------+-------------------+----+-------------------+-------------+------------+
|         start_time|           end_time|rate|         updated_at|is_duplicated|is_validated|
+-------------------+-------------------+----+-------------------+-------------+------------+
|2018-01-01 00:00:00|2018-01-02 00:00:00|  10|2021-02-25 00:00:00|        false|        null|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|2021-02-25 00:00:00|         true|        true|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|2021-02-25 00:00:00|         true|        true|
|2018-01-03 00:00:00|2018-01-04 00:00:00|  10|2021-02-25 00:00:00|         true|        true|
|2018-01-03 00:00:00|2018-01-04 00:00:00|  20|2021-02-20 00:00:00|         true|       false|
|2018-01-04 00:00:00|2018-01-05 00:00:00|  20|2021-02-20 00:00:00|        false|        null|
|2018-01-06 00:00:00|2018-01-07 00:00:00|  30|2021-02-25 00:00:00|        false|        null|
|2018-01-07 00:00:00|2018-01-08 00:00:00|  30|2021-02-25 00:00:00|        false|        null|
+-------------------+-------------------+----+-------------------+-------------+------------+
# To give you:
output_rows = [Row(start_time='2018-01-01 00:00:00', end_time='2018-01-02 00:00:00', rate=10),
               Row(start_time='2018-01-02 00:00:00', end_time='2018-01-03 00:00:00', rate=10),
               Row(start_time='2018-01-03 00:00:00', end_time='2018-01-04 00:00:00', rate=10),
               Row(start_time='2018-01-04 00:00:00', end_time='2018-01-05 00:00:00', rate=20),
               Row(start_time='2018-01-06 00:00:00', end_time='2018-01-07 00:00:00', rate=30),
               Row(start_time='2018-01-07 00:00:00', end_time='2018-01-08 00:00:00', rate=30)]
final_df = spark.createDataFrame(output_rows)
final_df.show()
>>>
+-------------------+-------------------+----+
|         start_time|           end_time|rate|
+-------------------+-------------------+----+
|2018-01-01 00:00:00|2018-01-02 00:00:00|  10|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|
|2018-01-03 00:00:00|2018-01-04 00:00:00|  10|
|2018-01-04 00:00:00|2018-01-05 00:00:00|  20|
|2018-01-06 00:00:00|2018-01-07 00:00:00|  30|
|2018-01-07 00:00:00|2018-01-08 00:00:00|  30|
+-------------------+-------------------+----+
Answer
This works:
from pyspark.sql import functions as F, Row, SparkSession, Window
from pyspark.sql.types import BooleanType

spark = (SparkSession.builder
         .master("local")
         .appName("Octopus")
         .config('spark.sql.autoBroadcastJoinThreshold', -1)
         .getOrCreate())
input_rows = [Row(idx=0, interval_start='2018-01-01 00:00:00', interval_end='2018-01-04 00:00:00', rate=10, updated_at='2021-02-25 00:00:00'),  # OVERLAP: (1,4) and (2,3) and (3,5) and rate=10/20
              Row(idx=0, interval_start='2018-01-02 00:00:00', interval_end='2018-01-03 00:00:00', rate=10, updated_at='2021-02-25 00:00:00'),  # OVERLAP: full overlap for (2,3) with (1,4)
              Row(idx=0, interval_start='2018-01-03 00:00:00', interval_end='2018-01-05 00:00:00', rate=20, updated_at='2021-02-20 00:00:00'),  # OVERLAP: (3,5) and (1,4) and rate=10/20
              Row(idx=0, interval_start='2018-01-06 00:00:00', interval_end='2018-01-07 00:00:00', rate=30, updated_at='2021-02-25 00:00:00'),  # NO OVERLAP: hole between (5,6)
              Row(idx=0, interval_start='2018-01-07 00:00:00', interval_end='2018-01-08 00:00:00', rate=30, updated_at='2021-02-25 00:00:00')]  # NO OVERLAP
df = spark.createDataFrame(input_rows)
df.show()
# Compute overlapping intervals.
# Note: the strict inequalities treat intervals that merely touch (one's end
# equals the other's start) as non-overlapping.
def overlap(start_first, end_first, start_second, end_second):
    return ((start_first < start_second < end_first) or (start_first < end_second < end_first)
            or (start_second < start_first < end_second) or (start_second < end_first < end_second))

spark.udf.register('overlap', overlap, BooleanType())
df.createOrReplaceTempView("df1")
df.createOrReplaceTempView("df2")
df = df.cache()
overlap_df = spark.sql("""
SELECT df1.idx, df1.interval_start, df1.interval_end, df1.rate AS rate FROM df1 JOIN df2
ON df1.idx == df2.idx
WHERE overlap(df1.interval_start, df1.interval_end, df2.interval_start, df2.interval_end)
""")
overlap_df = overlap_df.cache()
# Compute NON overlapping intervals
non_overlap_df = df.join(overlap_df, ['interval_start', 'interval_end'], 'leftanti')
# Stack overlapping points
interval_point = overlap_df.select('interval_start').union(overlap_df.select('interval_end'))
interval_point = interval_point.withColumnRenamed('interval_start', 'p').distinct().sort('p')
# Construct continuous overlapping intervals
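# Sort the distinct boundary points; for each point, the smallest later point
# is its immediate successor, so each pair of consecutive points delimits one
# elementary, non-overlapping sub-interval.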
w = Window.orderBy('p').rowsBetween(1, Window.unboundedFollowing)
interval_point = interval_point.withColumn('interval_end', F.min('p').over(w)).dropna(subset=['p', 'interval_end'])
interval_point = interval_point.withColumnRenamed('p', 'interval_start')
# Stack continuous overlapping intervals and non overlapping intervals
df3 = interval_point.select('interval_start', 'interval_end').union(non_overlap_df.select('interval_start', 'interval_end'))
# Point in interval range join
# https://docs.databricks.com/delta/join-performance/range-join.html
df3.createOrReplaceTempView("df3")
df.createOrReplaceTempView("df")
sql = """SELECT df3.interval_start, df3.interval_end, df.rate, df.updated_at
FROM df3 JOIN df ON df3.interval_start BETWEEN df.interval_start and df.interval_end - INTERVAL 1 seconds"""
df4 = spark.sql(sql)
df4.sort('interval_start').show()
# Select non-overlapped intervals as-is and keep the most up-to-date rate value for overlapping intervals
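# F.max over a struct compares fields left to right, so the max of
# (updated_at, rate) is the struct of the latest update, and ['rate'] extracts
# the rate that came with it.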
(df4.groupBy('interval_start', 'interval_end')
.agg(F.max(F.struct('updated_at', 'rate'))['rate'].alias('rate'))
.orderBy("interval_start")).show()
+-------------------+-------------------+----+
|     interval_start|       interval_end|rate|
+-------------------+-------------------+----+
|2018-01-01 00:00:00|2018-01-02 00:00:00|  10|
|2018-01-02 00:00:00|2018-01-03 00:00:00|  10|
|2018-01-03 00:00:00|2018-01-04 00:00:00|  10|
|2018-01-04 00:00:00|2018-01-05 00:00:00|  20|
|2018-01-06 00:00:00|2018-01-07 00:00:00|  30|
|2018-01-07 00:00:00|2018-01-08 00:00:00|  30|
+-------------------+-------------------+----+
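Note that the sub-interval (2018-01-03, 2018-01-04) ends up with rate 10: it is covered both by the interval carrying rate 10 (updated 2021-02-25) and by the one carrying rate 20 (updated 2021-02-20), and the more recent update wins.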