
LLM Utilities

LLM-powered DataFrame transformations.

This module provides functions that call external LLM APIs (OpenAI / Azure OpenAI) and LangChain embedding models to enrich PySpark DataFrames. Pure DataFrame column operations live in `spark_fuse.utils.transformations`.

with_langchain_embeddings

with_langchain_embeddings(df: DataFrame, input_col: str, embeddings: Union['Embeddings', Callable[[], 'Embeddings']], *, output_col: str = 'embedding', batch_size: int = 16, text_splitter: Optional[Union['TextSplitter', Callable[[], 'TextSplitter']]] = None, aggregation: str = 'mean', drop_input: bool = False) -> DataFrame

Add a column of vector embeddings using a LangChain Embeddings model.

The function uses a Pandas UDF to batch calls to embed_documents and reuse a single embeddings instance per executor. Provide either an instantiated LangChain embeddings object or a zero-argument callable that returns one—factories are useful when clients (e.g., OpenAI) are not picklable. Optionally supply a LangChain text splitter to chunk long inputs before embedding; chunk embeddings are combined using aggregation ("mean" or "first").
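The factory contract can be illustrated without Spark. The sketch below uses a hypothetical in-memory embeddings class (not a real LangChain model) to show what the function accepts as `embeddings`: either an object exposing `embed_documents`, or a zero-argument callable that returns one.

```python
# Hypothetical stand-in for a LangChain Embeddings model (illustration only).
class FakeEmbeddings:
    def embed_documents(self, texts):
        # Return one fixed-size vector per input text.
        return [[float(len(t)), 1.0] for t in texts]

# Option 1: pass an instance directly (fine when the client is picklable).
instance = FakeEmbeddings()

# Option 2: pass a zero-argument factory; each executor builds its own
# client, which sidesteps pickling issues with SDK clients such as OpenAI's.
def embeddings_factory():
    return FakeEmbeddings()

# Both forms satisfy the contract: the result exposes embed_documents().
vectors = embeddings_factory().embed_documents(["spark", "fuse"])
```

Either form can then be passed as the `embeddings` argument; the function validates a factory once on the driver before shipping it to executors.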

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `df` | `DataFrame` | Input DataFrame containing the raw text column. | *required* |
| `input_col` | `str` | Name of the column with text to embed. | *required* |
| `embeddings` | `Union['Embeddings', Callable[[], 'Embeddings']]` | LangChain embeddings instance or factory returning one. | *required* |
| `output_col` | `str` | Name of the resulting column containing `array<float>` vectors. | `'embedding'` |
| `batch_size` | `int` | Number of rows to embed per batch inside the UDF. | `16` |
| `text_splitter` | `Optional[Union['TextSplitter', Callable[[], 'TextSplitter']]]` | Optional LangChain text splitter (or factory) applied before embedding to chunk the text. | `None` |
| `aggregation` | `str` | Strategy to combine chunk embeddings when a splitter is provided. Supported values: `"mean"` (default) and `"first"`. | `'mean'` |
| `drop_input` | `bool` | Remove `input_col` from the resulting DataFrame when `True`. | `False` |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If `input_col` is missing, `batch_size` is not positive, the embeddings model returns a length mismatch, or `aggregation` is invalid. |
| `TypeError` | When `embeddings` is neither an embeddings instance nor a factory producing one, or when `text_splitter` lacks `split_text`. |
| `RuntimeError` | When the embeddings model or text splitter raises an exception during execution. Spark surfaces these as `pyspark.errors.PythonException`. |
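The chunk-aggregation behaviour can be reproduced in plain Python. Assuming a splitter produced two chunks for one input row, `"mean"` averages the chunk vectors elementwise, while `"first"` keeps only the first chunk's vector:

```python
chunk_vectors = [[1.0, 2.0], [3.0, 4.0]]  # two chunk embeddings for one row

# "mean": elementwise average across the row's chunk vectors.
mean_vec = [sum(vals) / len(chunk_vectors) for vals in zip(*chunk_vectors)]

# "first": keep the first chunk's embedding unchanged.
first_vec = chunk_vectors[0]
```

With these inputs, `"mean"` yields `[2.0, 3.0]` and `"first"` yields `[1.0, 2.0]`; rows without a splitter are a single "chunk", so both strategies return the embedding as-is.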

Source code in src/spark_fuse/utils/llm.py
def with_langchain_embeddings(
    df: DataFrame,
    input_col: str,
    embeddings: Union["Embeddings", Callable[[], "Embeddings"]],
    *,
    output_col: str = "embedding",
    batch_size: int = 16,
    text_splitter: Optional[Union["TextSplitter", Callable[[], "TextSplitter"]]] = None,
    aggregation: str = "mean",
    drop_input: bool = False,
) -> DataFrame:
    """Add a column of vector embeddings using a LangChain ``Embeddings`` model.

    The function uses a Pandas UDF to batch calls to ``embed_documents`` and reuse a
    single embeddings instance per executor. Provide either an instantiated LangChain
    embeddings object or a zero-argument callable that returns one—factories are useful
    when clients (e.g., OpenAI) are not picklable. Optionally supply a LangChain text
    splitter to chunk long inputs before embedding; chunk embeddings are combined using
    ``aggregation`` (``"mean"`` or ``"first"``).

    Args:
        df: Input DataFrame containing the raw text column.
        input_col: Name of the column with text to embed.
        embeddings: LangChain embeddings instance or factory returning one.
        output_col: Name of the resulting column containing ``array<float>`` vectors.
        batch_size: Number of rows to embed per batch inside the UDF.
        text_splitter: Optional LangChain text splitter (or factory) applied before
            embedding to chunk the text.
        aggregation: Strategy to combine chunk embeddings when a splitter is provided.
            Supported values: ``"mean"`` (default) and ``"first"``.
        drop_input: Remove ``input_col`` from the resulting DataFrame when ``True``.

    Raises:
        ValueError: If ``input_col`` is missing, ``batch_size`` is not positive, the
            embeddings model returns a length mismatch, or ``aggregation`` is invalid.
        TypeError: When ``embeddings`` is neither an embeddings instance nor a factory
            producing one, or when ``text_splitter`` lacks ``split_text``.
        RuntimeError: When the embeddings model or text splitter raises an exception
            during execution. Spark surfaces these as ``pyspark.errors.PythonException``.
    """

    if input_col not in df.columns:
        raise ValueError(f"Column '{input_col}' not found in DataFrame")

    if not isinstance(batch_size, int) or batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    agg_mode = aggregation.lower()
    if agg_mode not in {"mean", "first"}:
        raise ValueError("aggregation must be one of: 'mean', 'first'")

    def _resolve_embedder_factory() -> Callable[[], Any]:
        if hasattr(embeddings, "embed_documents"):
            return lambda: embeddings

        if callable(embeddings):

            def _factory():
                model = embeddings()
                if not hasattr(model, "embed_documents"):
                    raise TypeError(
                        "Embeddings factory must return an object with embed_documents()."
                    )
                return model

            # Validate once on the driver to surface configuration issues early.
            _factory()
            return _factory

        raise TypeError(
            "embeddings must be a LangChain Embeddings instance or a zero-argument factory."
        )

    def _resolve_splitter_factory():
        if text_splitter is None:
            return None
        if hasattr(text_splitter, "split_text"):
            return lambda: text_splitter
        if callable(text_splitter):

            def _factory():
                splitter_obj = text_splitter()
                if not hasattr(splitter_obj, "split_text"):
                    raise TypeError(
                        "Text splitter factory must return an object with split_text()."
                    )
                return splitter_obj

            _factory()
            return _factory
        raise TypeError(
            "text_splitter must be a LangChain TextSplitter instance or a zero-argument factory."
        )

    embedder_factory = _resolve_embedder_factory()
    splitter_factory = _resolve_splitter_factory()

    embedder_cache: Dict[str, Any] = {"model": None}
    splitter_cache: Dict[str, Any] = {"splitter": None}

    def _get_embedder():
        if embedder_cache["model"] is None:
            embedder_cache["model"] = embedder_factory()
        return embedder_cache["model"]

    def _get_splitter():
        if splitter_factory is None:
            return None
        if splitter_cache["splitter"] is None:
            splitter_cache["splitter"] = splitter_factory()
        return splitter_cache["splitter"]

    @pandas_udf(ArrayType(FloatType()))
    def _embed(text_series):
        import pandas as pd

        texts = ["" if value is None else str(value) for value in text_series.tolist()]
        embedder = _get_embedder()
        splitter = _get_splitter()

        flat_texts: list[str] = []
        counts: list[int] = []
        for value in texts:
            if splitter is None:
                chunks = [value]
            else:
                try:
                    chunks = splitter.split_text(value)
                except Exception as exc:
                    raise RuntimeError("Text splitter failed while processing input.") from exc
                if not chunks:
                    chunks = [value]

            flat_texts.extend(chunks)
            counts.append(len(chunks))

        vectors: list[Any] = []
        for start in range(0, len(flat_texts), batch_size):
            chunk = flat_texts[start : start + batch_size]
            try:
                chunk_vectors = embedder.embed_documents(chunk)
            except Exception as exc:
                raise RuntimeError(
                    f"LangChain embeddings failed for batch starting at index {start}"
                ) from exc

            if len(chunk_vectors) != len(chunk):
                raise ValueError(
                    "Embeddings model returned %s vectors for %s inputs"
                    % (len(chunk_vectors), len(chunk))
                )
            vectors.extend(chunk_vectors)

        aggregated: list[list[float]] = []
        cursor = 0

        def _aggregate_vectors(items: Sequence[Any]) -> list[float]:
            if not items:
                return []
            if agg_mode == "first":
                return [float(x) for x in items[0]]

            base = list(items[0])
            length = len(base)
            sums = [float(x) for x in base]
            for vec in items[1:]:
                if len(vec) != length:
                    raise ValueError("Embeddings model returned vectors of differing dimensions")
                for idx, val in enumerate(vec):
                    sums[idx] += float(val)
            count = float(len(items))
            return [val / count for val in sums]

        for count in counts:
            row_vectors = vectors[cursor : cursor + count]
            cursor += count
            aggregated.append(_aggregate_vectors(row_vectors))

        return pd.Series(aggregated)

    transformed = df.withColumn(output_col, _embed(F.col(input_col)))
    if drop_input:
        transformed = transformed.drop(input_col)
    return transformed

map_column_with_llm

map_column_with_llm(df: DataFrame, column: str, target_values: Union[Sequence[str], Mapping[str, Any]], *, model: str = 'gpt-3.5-turbo', dry_run: bool = False, max_retries: int = 3, request_timeout: int = 30, temperature: Optional[float] = 0.0) -> DataFrame

Map column values to target_values via a scalar PySpark UDF.

The transformation applies a regular user-defined function across the column, keeping a per-executor in-memory cache to avoid duplicate LLM calls. Spark accumulators track mapping statistics. When dry_run=True the UDF performs case-insensitive matching only and yields None for unmatched rows without contacting the LLM. When targeting models that require provider-managed sampling behaviour, set temperature=None to omit the temperature parameter from LLM requests.
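The dry-run matching and the per-executor cache can be sketched in plain Python (the names below are illustrative, not the module's internals). In dry-run mode only an exact case-insensitive match against the canonical set succeeds, and each distinct input is resolved once:

```python
targets = ["United States", "Germany"]
lookup = {t.lower(): t for t in targets}  # case-insensitive canonical lookup

cache = {}  # per-executor memoisation: one resolution per distinct input

def map_value_dry_run(raw):
    # Nulls and blank strings are never mapped.
    if raw is None or not str(raw).strip():
        return None
    value = str(raw)
    if value not in cache:
        # dry_run: exact case-insensitive match only, no LLM call.
        cache[value] = lookup.get(value.lower())
    return cache[value]
```

With `dry_run=False`, a cache miss triggers one LLM request instead of the plain dictionary lookup, so the cache bounds API calls to the number of distinct column values per executor.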

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `df` | `DataFrame` | Input DataFrame whose values should be normalized. | *required* |
| `column` | `str` | Source column containing the free-form text to map. | *required* |
| `target_values` | `Union[Sequence[str], Mapping[str, Any]]` | List or mapping defining the set of canonical outputs. When a mapping is provided, its keys are treated as the canonical set. | *required* |
| `model` | `str` | Chat model (or Azure deployment name) to query. | `'gpt-3.5-turbo'` |
| `dry_run` | `bool` | Skip external calls and simply echo canonical matches (useful for smoke testing and cost estimation). | `False` |
| `max_retries` | `int` | Retry budget passed to `_fetch_llm_mapping`. | `3` |
| `request_timeout` | `int` | Timeout in seconds for each HTTP request. | `30` |
| `temperature` | `Optional[float]` | LLM sampling temperature. Use `None` to skip explicitly setting it (some provider models accept only their default temperature). | `0.0` |

Returns:

| Type | Description |
| --- | --- |
| `DataFrame` | A new DataFrame with an additional `<column>_mapped` string column containing the canonical value or `None` when no match is determined. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the source column is missing or `target_values` is empty. |
| `TypeError` | When `target_values` contains non-string entries. |

Notes:

- The resulting DataFrame is cached so that logging the accumulator values does not trigger duplicate LLM requests.
- Provide API credentials via the environment variables documented in `_get_llm_api_config` before running with `dry_run=False`.
Source code in src/spark_fuse/utils/llm.py
def map_column_with_llm(
    df: DataFrame,
    column: str,
    target_values: Union[Sequence[str], Mapping[str, Any]],
    *,
    model: str = "gpt-3.5-turbo",
    dry_run: bool = False,
    max_retries: int = 3,
    request_timeout: int = 30,
    temperature: Optional[float] = 0.0,
) -> DataFrame:
    """Map ``column`` values to ``target_values`` via a scalar PySpark UDF.

    The transformation applies a regular user-defined function across the column, keeping
    a per-executor in-memory cache to avoid duplicate LLM calls. Spark accumulators track
    mapping statistics. When ``dry_run=True`` the UDF performs case-insensitive matching
    only and yields ``None`` for unmatched rows without contacting the LLM. When targeting
    models that require provider-managed sampling behaviour, set ``temperature=None`` to
    omit the ``temperature`` parameter from LLM requests.

    Args:
        df: Input DataFrame whose values should be normalized.
        column: Source column containing the free-form text to map.
        target_values: List or mapping defining the set of canonical outputs. When a
            mapping is provided, its keys are treated as the canonical set.
        model: Chat model (or Azure deployment name) to query.
        dry_run: Skip external calls and simply echo canonical matches (useful for smoke
            testing and cost estimation).
        max_retries: Retry budget passed to :func:`_fetch_llm_mapping`.
        request_timeout: Timeout in seconds for each HTTP request.
        temperature: LLM sampling temperature. Use ``None`` to skip explicitly setting it
            (some provider models accept only their default temperature).

    Returns:
        A new DataFrame with an additional ``<column>_mapped`` string column containing
        the canonical value or ``None`` when no match is determined.

    Raises:
        ValueError: If the source column is missing or ``target_values`` is empty.
        TypeError: When ``target_values`` contains non-string entries.

    Notes:
        - The resulting DataFrame is cached to ensure logging the accumulator values does
          not trigger duplicate LLM requests.
        - Provide API credentials via the environment variables documented in
          :func:`_get_llm_api_config` before running with ``dry_run=False``.
    """

    if column not in df.columns:
        raise ValueError(f"Column '{column}' not found in DataFrame")

    if isinstance(target_values, Mapping):
        targets = list(dict.fromkeys(target_values.keys()))
    else:
        targets = list(dict.fromkeys(target_values))

    if not targets:
        raise ValueError("target_values must contain at least one entry")

    if not all(isinstance(target, str) for target in targets):
        raise TypeError("target_values entries must be strings")

    lookup: Dict[str, str] = {target.lower(): target for target in targets}
    target_list = list(lookup.values())

    api_url: Optional[str] = None
    headers: Dict[str, str] = {}
    use_azure = False

    if not dry_run:
        api_url, headers, use_azure = _get_llm_api_config(model)

    spark = df.sparkSession
    sc = spark.sparkContext
    calls_acc = _create_long_accumulator(sc, f"llm_api_calls_{column}")
    mapped_acc = _create_long_accumulator(sc, f"mapped_entries_{column}")
    unmapped_acc = _create_long_accumulator(sc, f"unmapped_entries_{column}")

    new_col_name = f"{column}_mapped"

    def _make_mapper():
        cache: Dict[str, Optional[str]] = {}

        def _map_value(raw_value: Any) -> Optional[str]:
            if raw_value is None:
                unmapped_acc.add(1)
                return None

            value_str = str(raw_value)
            if value_str.strip() == "":
                unmapped_acc.add(1)
                return None

            if dry_run:
                mapped_value = lookup.get(value_str.lower())
                if mapped_value is None:
                    unmapped_acc.add(1)
                else:
                    mapped_acc.add(1)
                return mapped_value

            if value_str in cache:
                mapped_value = cache[value_str]
            else:
                calls_acc.add(1)
                mapped_candidate = _fetch_llm_mapping(
                    value_str,
                    target_list,
                    api_url=api_url,  # type: ignore[arg-type]
                    headers=headers,
                    use_azure=use_azure,
                    model=model,
                    max_retries=max_retries,
                    request_timeout=request_timeout,
                    temperature=temperature,
                )
                if mapped_candidate is not None:
                    mapped_value = lookup.get(mapped_candidate.lower(), mapped_candidate)
                else:
                    mapped_value = None
                cache[value_str] = mapped_value

            if mapped_value is None:
                unmapped_acc.add(1)
            else:
                mapped_acc.add(1)
            return mapped_value

        return _map_value

    mapper_udf = F.udf(_make_mapper(), StringType())

    mapped_df = df.withColumn(new_col_name, mapper_udf(F.col(column))).cache()
    mapped_df.count()

    mapped_count = mapped_acc.value
    unmapped_count = unmapped_acc.value
    _LOGGER.info(
        "Mapping stats for column '%s': Mapped %s, Unmapped %s, API calls made %s.",
        column,
        mapped_count,
        unmapped_count,
        calls_acc.value,
    )

    return mapped_df