..  Licensed to the Apache Software Foundation (ASF) under one
    or more contributor license agreements.  See the NOTICE file
    distributed with this work for additional information
    regarding copyright ownership.  The ASF licenses this file
    to you under the Apache License, Version 2.0 (the
    "License"); you may not use this file except in compliance
    with the License.  You may obtain a copy of the License at

..    http://www.apache.org/licenses/LICENSE-2.0

..  Unless required by applicable law or agreed to in writing,
    software distributed under the License is distributed on an
    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    KIND, either express or implied.  See the License for the
    specific language governing permissions and limitations
    under the License.

======================================================
Vectorized Python User-defined Table Functions (UDTFs)
======================================================

Spark 4.1 introduces the Vectorized Python user-defined table function (UDTF), a new type of
user-defined table-valued function. It can be used via the ``@arrow_udtf`` decorator.
Unlike scalar functions that return a single result value from each call, each UDTF is invoked in
the ``FROM`` clause of a query and returns an entire table as output.
Unlike the traditional Python UDTF, which evaluates row by row, the Vectorized Python UDTF lets you
operate directly on Apache Arrow arrays and record batches. This allows you to leverage vectorized
operations and improve the performance of your UDTF.

Vectorized Python UDTF Interface
--------------------------------

.. currentmodule:: pyspark.sql.functions

.. code-block:: python

    class NameYourArrowPythonUDTF:

        def __init__(self) -> None:
            """
            Initializes the user-defined table function (UDTF). This is optional.

            This method serves as the default constructor and is called once when the UDTF is
            instantiated on the executor side.

            Any class fields assigned in this method will be available for subsequent
            calls to the `eval`, `terminate` and `cleanup` methods.

            Notes
            -----
            - You cannot create or reference the Spark session within the UDTF. Any attempt to do
              so will result in a serialization error.
            """
            ...

        def eval(self, *args: Any) -> Iterator[pa.RecordBatch | pa.Table]:
            """
            Evaluates the function using the given input arguments.

            This method is required and must be implemented.

            Argument Mapping:
            - Each provided scalar expression maps to exactly one value in the
              `*args` list with type `pa.Array`.
            - Each provided table argument maps to a `pa.RecordBatch` object containing
              the columns in the order they appear in the provided input table,
              and with the names computed by the query analyzer.

            This method is called on every batch of input rows, and can produce zero or more
            output pyarrow record batches or pyarrow tables. Each yielded batch or table must
            contain the columns specified in the return type of the UDTF.

            Parameters
            ----------
            *args : Any
                Arbitrary positional arguments representing the input to the UDTF.

            Yields
            ------
            iterator
                An iterator of `pa.RecordBatch` or `pa.Table` objects, each representing a batch
                of rows in the UDTF result table. Yield as many times as needed to produce
                multiple batches.

            Notes
            -----
            - UDTFs can instead accept keyword arguments during the function call if needed.
            - The `eval` method can raise a `SkipRestOfInputTableException` to indicate that the
              UDTF wants to skip consuming all remaining rows from the current partition of the
              input table. This will cause the UDTF to proceed directly to the `terminate` method.
            - The `eval` method can raise any other exception to indicate that the UDTF should be
              aborted entirely. This will cause the UDTF to skip the `terminate` method and
              proceed directly to the `cleanup` method, and then the exception will be propagated
              to the query processor, causing the invoking query to fail.

            Examples
            --------
            This `eval` method takes a table argument and yields an arrow record batch for each
            input batch.

            >>> def eval(self, batch: pa.RecordBatch):
            ...     yield batch

            This `eval` method takes a table argument and yields a pyarrow table for each input
            batch.

            >>> def eval(self, batch: pa.RecordBatch):
            ...     yield pa.table({"x": batch.column(0), "y": batch.column(1)})

            This `eval` method takes both table and scalar arguments and yields a pyarrow table
            for each input batch.

            >>> def eval(self, batch: pa.RecordBatch, x: pa.Array):
            ...     yield pa.table({"x": x, "y": batch.column(0)})
            """
            ...

        def terminate(self) -> Iterator[pa.RecordBatch | pa.Table]:
            """
            Called when the UDTF has successfully processed all input rows.

            This method is optional to implement and is useful for performing any finalization
            operations after the UDTF has finished processing all rows. It can also be used to
            yield additional rows if needed. Table functions that consume all rows in the entire
            input partition and then compute and return the entire output table can do so from
            this method as well (please be mindful of memory usage when doing this).

            If any exceptions occur during input row processing, this method won't be called.

            Yields
            ------
            iterator
                An iterator of `pa.RecordBatch` or `pa.Table` objects, each representing a batch
                of rows in the UDTF result table. Yield as many times as needed to produce
                multiple batches.

            Examples
            --------
            >>> def terminate(self) -> Iterator[pa.RecordBatch | pa.Table]:
            ...     yield pa.table({"x": pa.array([1, 2, 3])})
            """
            ...

        def cleanup(self) -> None:
            """
            Invoked after the UDTF completes processing input rows.

            This method is optional to implement and is useful for final cleanup regardless
            of whether the UDTF processed all input rows successfully or was aborted due to
            exceptions.

            Examples
            --------
            >>> def cleanup(self) -> None:
            ...     self.conn.close()
            """
            ...
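
Because the interface is plain Python, a UDTF's logic can be exercised without a Spark session.
The following is a minimal sketch that assumes only ``pyarrow``. ``RunningTotal`` and
``run_locally`` are hypothetical names. ``run_locally`` only imitates the call order described
above:

- ``eval`` is called for each batch.
- ``terminate`` is called unless ``eval`` raised an exception other than
  ``SkipRestOfInputTableException``.
- ``cleanup`` is always called.

It is not Spark's executor logic. In Spark, you would decorate the class with
``@arrow_udtf(returnType="total bigint")`` instead.

.. code-block:: python

    import pyarrow as pa
    import pyarrow.compute as pc
    from typing import Iterable, Iterator, List, Union

    try:
        from pyspark.sql.functions import SkipRestOfInputTableException
    except ImportError:
        # Local stand-in so the sketch runs without PySpark (assumption, not Spark's class).
        class SkipRestOfInputTableException(Exception):
            pass


    class RunningTotal:
        """Sums column 0 over all input batches and emits one row from terminate()."""

        def __init__(self) -> None:
            self.total = 0
            self.cleaned_up = False

        def eval(self, batch: pa.RecordBatch) -> Iterator[pa.Table]:
            # pc.sum returns a null scalar for an empty or all-null column; treat it as 0.
            self.total += pc.sum(batch.column(0)).as_py() or 0
            return iter(())  # nothing is emitted per batch

        def terminate(self) -> Iterator[pa.Table]:
            yield pa.table({"total": pa.array([self.total], pa.int64())})

        def cleanup(self) -> None:
            self.cleaned_up = True


    def _to_table(out: Union[pa.RecordBatch, pa.Table]) -> pa.Table:
        return pa.Table.from_batches([out]) if isinstance(out, pa.RecordBatch) else out


    def run_locally(udtf, batches: Iterable[pa.RecordBatch]) -> List[pa.Table]:
        """Hypothetical test helper: eval per batch -> terminate -> cleanup (always)."""
        results: List[pa.Table] = []
        try:
            try:
                for batch in batches:
                    results.extend(_to_table(out) for out in udtf.eval(batch))
            except SkipRestOfInputTableException:
                pass  # stop consuming input, but still call terminate()
            # Any other exception skips terminate() and propagates after cleanup().
            results.extend(_to_table(out) for out in udtf.terminate())
        finally:
            udtf.cleanup()
        return results


    batches = [
        pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], names=["x"]),
        pa.RecordBatch.from_arrays([pa.array([10])], names=["x"]),
    ]
    udtf = RunningTotal()
    out = run_locally(udtf, batches)
    print(out[0].to_pydict())  # {'total': [16]}

Accumulating in ``eval`` and emitting once from ``terminate`` mirrors the per-group aggregation
pattern used later on this page. Swapping in a class whose ``eval`` raises lets you check that
``cleanup`` still runs.
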

Defining the Output Schema
--------------------------

The return type of the UDTF defines the schema of the table it outputs. You can specify it in the
``@arrow_udtf`` decorator. It must be either a ``StructType``:

.. code-block:: python

    @arrow_udtf(returnType=StructType().add("c1", StringType()).add("c2", IntegerType()))
    class YourArrowPythonUDTF:
        ...

or a DDL string representing a struct type:

.. code-block:: python

    @arrow_udtf(returnType="c1 string, c2 int")
    class YourArrowPythonUDTF:
        ...

Emitting Output Rows
--------------------

The ``eval`` and ``terminate`` methods then emit zero or more output batches conforming to this
schema by yielding ``pa.RecordBatch`` or ``pa.Table`` objects.

.. code-block:: python

    @arrow_udtf(returnType="c1 int, c2 int")
    class YourArrowPythonUDTF:
        def eval(self, batch: pa.RecordBatch):
            yield pa.table({"c1": batch.column(0), "c2": batch.column(1)})

You can also yield multiple pyarrow tables in the ``eval`` method.

.. code-block:: python

    @arrow_udtf(returnType="c1 int")
    class YourArrowPythonUDTF:
        def eval(self, batch: pa.RecordBatch):
            yield pa.table({"c1": batch.column(0)})
            yield pa.table({"c1": batch.column(1)})

You can also yield multiple pyarrow record batches in the ``eval`` method.

.. code-block:: python

    @arrow_udtf(returnType="c1 int")
    class YourArrowPythonUDTF:
        def eval(self, batch: pa.RecordBatch):
            # Emit the first column as two record batches: first half, then second half
            half = len(batch) // 2
            yield pa.record_batch({"c1": batch.column(0).slice(0, half)})
            yield pa.record_batch({"c1": batch.column(0).slice(half)})

Usage Examples
--------------

Here's how to use these UDTFs with the DataFrame API and in SQL:

.. code-block:: python

    import pyarrow as pa
    from pyspark.sql.functions import arrow_udtf

    @arrow_udtf(returnType="c1 string")
    class MyArrowPythonUDTF:
        def eval(self, batch: pa.RecordBatch):
            yield pa.table({"c1": batch.column("value")})

    df = spark.range(10).selectExpr("id", "cast(id as string) as value")
    MyArrowPythonUDTF(df.asTable()).show()

    # Result:
    # +---+
    # | c1|
    # +---+
    # |  0|
    # |  1|
    # |  2|
    # |  3|
    # |  4|
    # |  5|
    # |  6|
    # |  7|
    # |  8|
    # |  9|
    # +---+

    # Register the UDTF
    spark.udtf.register("my_arrow_udtf", MyArrowPythonUDTF)

    # Use in SQL queries
    df = spark.sql("""
        SELECT * FROM my_arrow_udtf(TABLE(SELECT id, cast(id as string) as value FROM range(10)))
    """)
    df.show()

TABLE Argument
--------------

Arrow UDTFs can take a TABLE argument. When your UDTF receives a TABLE argument, its ``eval`` method
is called with a ``pyarrow.RecordBatch`` containing the input table's columns. Any additional
scalar or struct expressions are passed as ``pyarrow.Array`` values.

Key points:

- In each ``eval`` call, the TABLE argument is a single ``pa.RecordBatch``. Access columns by
  name or index.
- Scalar arguments (including structs) are ``pa.Array`` values, not ``RecordBatch`` objects.
- Named and positional arguments are both supported in SQL.

Example (DataFrame API):

.. code-block:: python

    import pyarrow as pa
    from typing import Iterator
    from pyspark.sql.functions import arrow_udtf

    @arrow_udtf(returnType="value int")
    class EchoTable:
        def eval(self, batch: pa.RecordBatch) -> Iterator[pa.Table]:
            # Return the input column named "value" as-is
            yield pa.table({"value": batch.column("value")})

    # Cast to int so the input column matches the declared "value int" return type
    df = spark.range(5).selectExpr("cast(id as int) as value")
    EchoTable(df.asTable()).show()

    # Result:
    # +-----+
    # |value|
    # +-----+
    # |    0|
    # |    1|
    # |    2|
    # |    3|
    # |    4|
    # +-----+

Example (SQL): TABLE plus a scalar threshold

.. code-block:: python

    import pyarrow as pa
    import pyarrow.compute as pc
    from typing import Iterator
    from pyspark.sql.functions import arrow_udtf

    # Keep rows with value > threshold; works with SQL using TABLE + scalar argument
    @arrow_udtf(returnType="partition_key int, value int")
    class ThresholdFilter:
        def eval(self, batch: pa.RecordBatch, threshold: pa.Array) -> Iterator[pa.Table]:
            tbl = pa.table(batch)
            # The scalar argument arrives as a pa.Array; read its first element
            thr = int(threshold.cast(pa.int64())[0].as_py())
            mask = pc.greater(tbl["value"], thr)
            yield tbl.filter(mask)

    spark.udtf.register("threshold_filter", ThresholdFilter)
    spark.createDataFrame(
        [(1, 10), (1, 30), (2, 5)], "partition_key int, value int"
    ).createOrReplaceTempView("v")
    spark.sql(
        """
        SELECT *
        FROM threshold_filter(
            TABLE(v),
            10
        )
        ORDER BY partition_key, value
        """
    ).show()

    # Result:
    # +-------------+-----+
    # |partition_key|value|
    # +-------------+-----+
    # |            1|   30|
    # +-------------+-----+

PARTITION BY and ORDER BY
-------------------------

Arrow UDTFs support ``TABLE(...) PARTITION BY ... ORDER BY ...``. Think of it as "process rows group
by group, and in a specific order within each group".

Semantics:

- PARTITION BY groups rows by the given keys; your UDTF runs for each group independently.
- ORDER BY controls the row order within each group as seen by ``eval``.
- ``eval`` may be called multiple times per group, so accumulate state and typically emit the
  group's result in ``terminate``.

Example: Aggregation per key with terminate
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

PARTITION BY is especially useful for per-group aggregation. ``eval`` may be called multiple times
for the same group as rows arrive in batches, so keep running totals in the UDTF instance and emit
the final row in ``terminate``.

.. code-block:: python

    import pyarrow as pa
    import pyarrow.compute as pc
    from typing import Iterator
    from pyspark.sql.functions import arrow_udtf

    @arrow_udtf(returnType="user_id int, total_amount int, rows int")
    class SumPerUser:
        def __init__(self):
            self._user = None
            self._sum = 0
            self._count = 0

        def eval(self, batch: pa.RecordBatch) -> Iterator[pa.Table]:
            tbl = pa.table(batch)
            # All rows in this batch belong to the same user within a partition
            self._user = pc.unique(tbl["user_id"]).to_pylist()[0]
            self._sum += pc.sum(tbl["amount"]).as_py()
            self._count += tbl.num_rows
            return iter(())  # emit once in terminate

        def terminate(self) -> Iterator[pa.Table]:
            if self._user is not None:
                yield pa.table({
                    "user_id": pa.array([self._user], pa.int32()),
                    "total_amount": pa.array([self._sum], pa.int32()),
                    "rows": pa.array([self._count], pa.int32()),
                })

    spark.udtf.register("sum_per_user", SumPerUser)
    spark.createDataFrame(
        [(1, 10), (2, 5), (1, 20), (2, 15), (3, 7)],
        "user_id int, amount int",
    ).createOrReplaceTempView("purchases")
    spark.sql(
        """
        SELECT *
        FROM sum_per_user(
            TABLE(purchases)
            PARTITION BY user_id
        )
        ORDER BY user_id
        """
    ).show()

    # Result:
    # +-------+------------+----+
    # |user_id|total_amount|rows|
    # +-------+------------+----+
    # |      1|          30|   2|
    # |      2|          20|   2|
    # |      3|           7|   1|
    # +-------+------------+----+

Why terminate? ``eval`` may run multiple times per group if the input is split into several
batches. Emitting the aggregated row in ``terminate`` guarantees exactly one output row per group
after all its rows have been processed.

Example: Top reviews per product using ORDER BY
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Because ORDER BY fixes the order in which ``eval`` sees each group's rows, the first *K* rows
of a group are its top *K*. Once they are collected, raising ``SkipRestOfInputTableException``
skips the remainder of that group.

.. code-block:: python

    import pyarrow as pa
    import pyarrow.compute as pc
    from typing import Iterator, Optional
    from pyspark.sql.functions import arrow_udtf, SkipRestOfInputTableException

    @arrow_udtf(returnType="product_id int, review_id int, rating int, review string")
    class TopReviewsPerProduct:
        TOP_K = 3

        def __init__(self):
            self._product = None
            self._seen = 0
            self._batches: list[pa.Table] = []
            self._top: Optional[pa.Table] = None

        def eval(self, batch: pa.RecordBatch) -> Iterator[pa.Table]:
            tbl = pa.table(batch)
            if tbl.num_rows == 0:
                return iter(())

            products = pc.unique(tbl["product_id"]).to_pylist()
            assert len(products) == 1, f"Expected one product per batch, saw {products}"
            product = products[0]
            if self._product is None:
                self._product = product
            else:
                assert self._product == product, f"Mixed products {self._product} and {product}"

            self._batches.append(tbl)
            self._seen += tbl.num_rows

            if self._seen >= self.TOP_K and self._top is None:
                combined = pa.concat_tables(self._batches)
                self._top = combined.slice(0, self.TOP_K)
                raise SkipRestOfInputTableException(
                    f"Collected top {self.TOP_K} reviews for product {self._product}"
                )

            return iter(())

        def terminate(self) -> Iterator[pa.Table]:
            if self._product is None:
                return iter(())

            if self._top is None:
                combined = pa.concat_tables(self._batches) if self._batches else pa.table({})
                limit = min(self.TOP_K, self._seen)
                self._top = combined.slice(0, limit)

            yield self._top

    spark.udtf.register("top_reviews_per_product", TopReviewsPerProduct)
    spark.createDataFrame(
        [
            (101, 1, 5, "Amazing battery life"),
            (101, 2, 5, "Still great after a month"),
            (101, 3, 4, "Solid build"),
            (101, 4, 3, "Average sound"),
            (202, 5, 5, "My go-to lens"),
            (202, 6, 4, "Sharp and bright"),
            (202, 7, 4, "Great value"),
        ],
        "product_id int, review_id int, rating int, review string",
    ).createOrReplaceTempView("reviews")

    spark.sql(
        """
        SELECT *
        FROM top_reviews_per_product(
            TABLE(reviews)
            PARTITION BY (product_id)
            ORDER BY (rating DESC, review_id)
        )
        ORDER BY product_id, rating DESC, review_id
        """
    ).show(truncate=False)

    # Result:
    # +----------+---------+------+-------------------------+
    # |product_id|review_id|rating|review                   |
    # +----------+---------+------+-------------------------+
    # |101       |1        |5     |Amazing battery life     |
    # |101       |2        |5     |Still great after a month|
    # |101       |3        |4     |Solid build              |
    # |202       |5        |5     |My go-to lens            |
    # |202       |6        |4     |Sharp and bright         |
    # |202       |7        |4     |Great value              |
    # +----------+---------+------+-------------------------+

Best Practices
--------------

- Stream work from :py:meth:`eval` when possible. Yielding one ``pa.Table`` per Arrow batch keeps
  memory bounded and shortens feedback loops; reserve :py:meth:`terminate` for true per-partition
  operations.
- Keep per-partition state small and reset it promptly. If you only need the first *N* rows, raise
  :py:class:`~pyspark.sql.functions.SkipRestOfInputTableException` after collecting them so Spark
  skips the rest of the partition.
- Guard external calls with short timeouts and operate on the current batch instead of deferring to
  ``terminate``; this avoids giant buffers and keeps retries narrow.

When to use Arrow UDTFs vs Other UDTFs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

- Prefer ``arrow_udtf`` when the logic is naturally vectorized, you can stay within Python, and the
  input/output schema is Arrow-friendly. You gain batch-friendly performance and native
  interoperability with PySpark DataFrames.
- Stick with the classic (row-based) Python UDTF when you only need simple per-row expansion, or
  when your logic depends on Python objects that Arrow cannot represent cleanly.
- Use SQL UDTFs if the functionality is performance critical and the logic can be represented in
  SQL.

More Examples
-------------

Example: Simple anomaly detection per device
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Compute simple per-device stats and return them from ``terminate``. This pattern is useful for
anomaly detection workflows that first summarize distributions by key.

.. code-block:: python

    import pyarrow as pa
    import pyarrow.compute as pc
    from typing import Iterator
    from pyspark.sql.functions import arrow_udtf

    @arrow_udtf(returnType="device_id int, count int, mean double, stddev double, max_value int")
    class DeviceStats:
        def __init__(self):
            self._device = None
            self._count = 0
            self._sum = 0.0
            self._sumsq = 0.0
            self._max = None

        def eval(self, batch: pa.RecordBatch) -> Iterator[pa.Table]:
            tbl = pa.table(batch)
            self._device = pc.unique(tbl["device_id"]).to_pylist()[0]
            vals = tbl["reading"].cast(pa.float64())
            self._count += len(vals)
            self._sum += pc.sum(vals).as_py() or 0.0
            self._sumsq += pc.sum(pc.multiply(vals, vals)).as_py() or 0.0
            cur_max = pc.max(vals).as_py()
            self._max = cur_max if self._max is None else max(self._max, cur_max)
            return iter(())

        def terminate(self) -> Iterator[pa.Table]:
            if self._device is not None and self._count > 0:
                mean = self._sum / self._count
                # Population variance: E[x^2] - mean^2, clamped at 0 against rounding error
                var = max(self._sumsq / self._count - mean * mean, 0.0)
                std = var ** 0.5
                # Round to 2 decimal places for display
                mean_rounded = round(mean, 2)
                std_rounded = round(std, 2)
                yield pa.table({
                    "device_id": pa.array([self._device], pa.int32()),
                    "count": pa.array([self._count], pa.int32()),
                    "mean": pa.array([mean_rounded], pa.float64()),
                    "stddev": pa.array([std_rounded], pa.float64()),
                    "max_value": pa.array([int(self._max)], pa.int32()),
                })

    spark.udtf.register("device_stats", DeviceStats)
    spark.createDataFrame(
        [(1, 10), (1, 12), (1, 100), (2, 5), (2, 7)],
        "device_id int, reading int",
    ).createOrReplaceTempView("readings")
    spark.sql(
        """
        SELECT *
        FROM device_stats(
            TABLE(readings)
            PARTITION BY device_id
        )
        ORDER BY device_id
        """
    ).show()

    # Result:
    # +---------+-----+-----+------+---------+
    # |device_id|count| mean|stddev|max_value|
    # +---------+-----+-----+------+---------+
    # |        1|    3|40.67| 41.96|      100|
    # |        2|    2|  6.0|   1.0|        7|
    # +---------+-----+-----+------+---------+

Example: Arrow UDTFs as RDD map-style transforms
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Arrow UDTFs can replace many ``RDD.map``/``flatMap``-style transforms with better performance and
first-class SQL integration. Instead of mapping row by row in Python, you work on Arrow batches and
return a table.

Example: tokenize text into words (flatMap-like)

.. code-block:: python

    import pyarrow as pa
    from typing import Iterator
    from pyspark.sql.functions import arrow_udtf

    @arrow_udtf(returnType="doc_id int, word string")
    class Tokenize:
        def eval(self, batch: pa.RecordBatch) -> Iterator[pa.Table]:
            tbl = pa.table(batch)
            # Split on whitespace; build flat arrays for (doc_id, word)
            doc_ids: list[int] = []
            words: list[str] = []
            for doc_id, text in zip(tbl["doc_id"].to_pylist(), tbl["text"].to_pylist()):
                for w in (text or "").split():
                    doc_ids.append(doc_id)
                    words.append(w)
            if doc_ids:
                yield pa.table({"doc_id": pa.array(doc_ids, pa.int32()), "word": pa.array(words)})

    spark.udtf.register("tokenize", Tokenize)
    spark.createDataFrame(
        [(1, "spark is fast"), (2, "arrow udtf")], "doc_id int, text string"
    ).createOrReplaceTempView("docs")
    spark.sql("SELECT * FROM tokenize(TABLE(docs)) ORDER BY doc_id, word").show()

    # Result:
    # +------+-----+
    # |doc_id| word|
    # +------+-----+
    # |     1| fast|
    # |     1|   is|
    # |     1|spark|
    # |     2|arrow|
    # |     2| udtf|
    # +------+-----+
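
The ``Tokenize`` example still converts each batch to Python lists and loops over rows. The
following is a minimal sketch of a fully vectorized variant. It assumes a ``pyarrow`` version that
provides the ``utf8_split_whitespace``, ``list_flatten`` and ``list_parent_indices`` compute
functions. ``TokenizeVectorized`` is a hypothetical name, and it is written as a plain class so it
can be tested without Spark. Unlike Python's ``str.split()``, Arrow's whitespace split may produce
empty tokens (for example, from leading or trailing whitespace), so the sketch filters them out.

.. code-block:: python

    import pyarrow as pa
    import pyarrow.compute as pc
    from typing import Iterator


    class TokenizeVectorized:
        """flatMap-style tokenizer built from Arrow compute kernels (no per-row Python loop)."""

        def eval(self, batch: pa.RecordBatch) -> Iterator[pa.Table]:
            # One list<string> per input row; a null text yields a null list (no words).
            tokens = pc.utf8_split_whitespace(batch.column("text"))
            # All words of the batch, in row order.
            words = pc.list_flatten(tokens)
            # For each flattened word, the index of the row it came from.
            parents = pc.list_parent_indices(tokens)
            # Repeat each row's doc_id once per word (keeps the int32 input type).
            doc_ids = pc.take(batch.column("doc_id"), parents)
            # Drop empty tokens that the whitespace split may produce.
            keep = pc.not_equal(words, "")
            words = words.filter(keep)
            doc_ids = doc_ids.filter(keep)
            if len(words) > 0:
                yield pa.table({"doc_id": doc_ids, "word": words})


    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3], pa.int32()), pa.array(["spark is fast", "arrow udtf", None])],
        names=["doc_id", "text"],
    )
    for table in TokenizeVectorized().eval(batch):
        print(table.to_pydict())

To use it in Spark, wrap the class as the decorator form does, for example
``spark.udtf.register("tokenize_vectorized", arrow_udtf(returnType="doc_id int, word string")(TokenizeVectorized))``.
The SQL query from the ``Tokenize`` example should then return the same rows.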