Hosted with nbsanity. See source notebook on GitHub.

# %%
import pyarrow as pa
import duckdb
from duckdb.typing import VARCHAR
from duckdb.functional import PythonUDFType
from time import sleep
import threading

import time
from typing import Any


# Statistics shared by every invocation of my_arrow_udf. The recorded run of
# this notebook shows several invocations in flight at once, so every
# read-modify-write of these counters happens while holding _stats_lock.
num_calls = 0
average_num_passed_into_each_udf = 0
num_calls_currently_in_flight = 0
max_num_calls_currently_in_flight = 0
# One lock shared by all calls. A lock created inside the function would be
# private to that call and would therefore never exclude anybody.
_stats_lock = threading.Lock()


# Arrow python implementation (operates over 2048 rows at a time)
def my_arrow_udf(url_arr: pa.ChunkedArray) -> pa.Array:
    """Return the string "foo" for every row of ``url_arr`` and record call statistics.

    The function is a stand-in for a slow, per-batch operation: it sleeps
    briefly so that concurrent invocations become visible in the counters.

    Args:
        url_arr: The batch of URL values handed to the UDF.

    Returns:
        An Arrow string array with one ``"foo"`` entry per input row.

    Side effects:
        Updates the module-level counters ``num_calls``,
        ``average_num_passed_into_each_udf`` (running mean of rows per call),
        ``num_calls_currently_in_flight`` and
        ``max_num_calls_currently_in_flight``.
    """
    global num_calls
    global average_num_passed_into_each_udf
    global num_calls_currently_in_flight
    global max_num_calls_currently_in_flight

    batch_size = len(url_arr)

    with _stats_lock:
        num_calls += 1
        num_calls_currently_in_flight += 1
        max_num_calls_currently_in_flight = max(
            max_num_calls_currently_in_flight, num_calls_currently_in_flight
        )
        # Incremental mean: updated together with num_calls under the same
        # lock, so the pair always describes the same set of calls.
        average_num_passed_into_each_udf += (
            batch_size - average_num_passed_into_each_udf
        ) / num_calls

    try:
        urls = [url for chunk in url_arr.chunks for url in chunk.to_pylist()]

        # We sleep just a bit to illustrate the parallelism.
        # Without parallelism, 586 calls * 0.1s would take 58.6 seconds;
        # with 10 calls in flight at a time that drops to roughly 5.86 seconds.
        sleep(0.1)
        results = ["foo"] * len(urls)
        return pa.array(results, type=pa.string())
    finally:
        # Always release the in-flight slot, even if the body raised.
        with _stats_lock:
            num_calls_currently_in_flight -= 1


# Make the cell re-runnable: drop any previous registration of the UDF first.
# Best-effort on purpose -- presumably remove_function raises when no function
# of that name is registered yet (e.g. on the first run); any error is ignored.
try:
    duckdb.remove_function("my_arrow_udf")
except Exception:
    pass
# Register the Python function as a SQL-callable UDF taking one VARCHAR and
# returning a VARCHAR, using the Arrow UDF type.
duckdb.create_function("my_arrow_udf", my_arrow_udf, [VARCHAR], VARCHAR, type=PythonUDFType.ARROW)

# Set up a sample table with an integer id column and a url column.
duckdb.sql("CREATE OR REPLACE TABLE example_table (id INTEGER, url VARCHAR)")
# Fill it from range(1_200_000), pairing every generated value with the same URL.
duckdb.sql(
    """
          INSERT INTO example_table 
          SELECT *, 'https://example.com' as url 
          FROM range(1_200_000)
           """
)
# %%
# Execute and fetch the results
# Reset the counters first so that re-running this cell reports only this run.
num_calls = 0
average_num_passed_into_each_udf = 0
num_calls_currently_in_flight = 0
max_num_calls_currently_in_flight = 0
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# elapsed time; time.time() is wall-clock time and can jump if the system
# clock is adjusted.
start_time: float = time.perf_counter()
res_arrow: list[Any] = duckdb.sql(
    "SELECT my_arrow_udf(url) FROM example_table"
).fetchall()  # The UDF is called once per batch of rows; several calls may be in flight at once
end_time: float = time.perf_counter()
print(f"Arrow UDF took {end_time - start_time} seconds")
print(f"Number of calls: {num_calls}")
print(f"Number of rows in each chunk per call: {average_num_passed_into_each_udf}")
print(f"Max number of calls ever in flight: {max_num_calls_currently_in_flight}")
Arrow UDF took 6.561621904373169 seconds
Number of calls: 586
Number of rows in each chunk per call: 2048.0
Max number of calls ever in flight: 10