
Transformation Utilities

rename_columns

rename_columns(df: DataFrame, mapping: Mapping[str, str]) -> DataFrame

Rename columns according to mapping while preserving column order.

Raises:

    ValueError: If any source column is missing or the resulting columns collide.

Source code in src/spark_fuse/utils/transformations.py
def rename_columns(df: DataFrame, mapping: Mapping[str, str]) -> DataFrame:
    """Rename columns according to ``mapping`` while preserving column order.

    Raises:
        ValueError: If any source column is missing or the resulting columns collide.
    """
    if not mapping:
        return df

    missing = [name for name in mapping if name not in df.columns]
    if missing:
        raise ValueError(f"Cannot rename missing columns: {missing}")

    final_names = [mapping.get(name, name) for name in df.columns]
    if len(final_names) != len(set(final_names)):
        raise ValueError("Renaming results in duplicate column names")

    renamed = []
    for name in df.columns:
        new_name = mapping.get(name, name)
        column = F.col(name)
        if new_name != name:
            column = column.alias(new_name)
        renamed.append(column)
    return df.select(*renamed)
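
Example — a minimal usage sketch, assuming a local SparkSession and that the module is importable as spark_fuse.utils.transformations (matching the source path above); the sample data is illustrative:

from pyspark.sql import SparkSession
from spark_fuse.utils.transformations import rename_columns

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1, "alice")], ["id", "user_name"])

# Rename one column; columns not listed keep their name and position.
renamed = rename_columns(df, {"user_name": "username"})
print(renamed.columns)  # ['id', 'username']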

with_constants

with_constants(df: DataFrame, constants: Mapping[str, Any], *, overwrite: bool = False) -> DataFrame

Add literal-valued columns using constants.

Parameters:

    constants (Mapping[str, Any], required): Mapping of column name to literal value.
    overwrite (bool, default False): Replace existing columns when True.

Raises:

    ValueError: If attempting to add an existing column without overwrite.

Source code in src/spark_fuse/utils/transformations.py
def with_constants(
    df: DataFrame,
    constants: Mapping[str, Any],
    *,
    overwrite: bool = False,
) -> DataFrame:
    """Add literal-valued columns using ``constants``.

    Args:
        constants: Mapping of column name to literal value.
        overwrite: Replace existing columns when ``True`` (default ``False``).

    Raises:
        ValueError: If attempting to add an existing column without ``overwrite``.
    """
    if not constants:
        return df

    if not overwrite:
        duplicates = [name for name in constants if name in df.columns]
        if duplicates:
            raise ValueError(f"Columns already exist: {duplicates}")

    result = df
    for name, value in constants.items():
        result = result.withColumn(name, F.lit(value))
    return result
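
Example — a short sketch under the same assumptions as above (local SparkSession, spark_fuse.utils.transformations importable); the column names are illustrative:

from pyspark.sql import SparkSession
from spark_fuse.utils.transformations import with_constants

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1,), (2,)], ["id"])

# Add literal-valued columns; reusing an existing column name requires overwrite=True.
tagged = with_constants(df, {"source": "crm", "batch_id": 42})
print(tagged.columns)  # ['id', 'source', 'batch_id']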

cast_columns

cast_columns(df: DataFrame, type_mapping: TypeMapping) -> DataFrame

Cast columns to new Spark SQL types.

The type_mapping values may be str or DataType instances.

Raises:

    ValueError: If any referenced column is missing.

Source code in src/spark_fuse/utils/transformations.py
def cast_columns(df: DataFrame, type_mapping: TypeMapping) -> DataFrame:
    """Cast columns to new Spark SQL types.

    The ``type_mapping`` values may be ``str`` or ``DataType`` instances.

    Raises:
        ValueError: If any referenced column is missing.
    """
    if not type_mapping:
        return df

    missing = [name for name in type_mapping if name not in df.columns]
    if missing:
        raise ValueError(f"Cannot cast missing columns: {missing}")

    coerced = []
    for name in df.columns:
        if name in type_mapping:
            coerced.append(F.col(name).cast(type_mapping[name]).alias(name))
        else:
            coerced.append(F.col(name))
    return df.select(*coerced)
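
Example — a minimal sketch (local SparkSession assumed); type names may be strings or DataType instances, as noted above:

from pyspark.sql import SparkSession
from pyspark.sql.types import DateType
from spark_fuse.utils.transformations import cast_columns

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([("1", "2024-01-31")], ["id", "created_at"])

# Cast by string type name or by DataType instance; other columns pass through unchanged.
casted = cast_columns(df, {"id": "int", "created_at": DateType()})
casted.printSchema()
# root
#  |-- id: integer (nullable = true)
#  |-- created_at: date (nullable = true)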

normalize_whitespace

normalize_whitespace(df: DataFrame, columns: Iterable[str], *, trim_ends: bool = True, pattern: str = _DEFAULT_REGEX, replacement: str = ' ') -> DataFrame

Collapse repeated whitespace in string columns.

Parameters:

    columns (Iterable[str], required): Iterable of column names to normalize. Duplicates are ignored.
    trim_ends (bool, default True): When True, also trim the resulting string.
    pattern (str, default _DEFAULT_REGEX): Regex pattern to match; defaults to consecutive whitespace.
    replacement (str, default ' '): Replacement string for the regex matches.

Raises:

    TypeError: If columns is provided as a single str.
    ValueError: If any referenced column is missing.

Source code in src/spark_fuse/utils/transformations.py
def normalize_whitespace(
    df: DataFrame,
    columns: Iterable[str],
    *,
    trim_ends: bool = True,
    pattern: str = _DEFAULT_REGEX,
    replacement: str = " ",
) -> DataFrame:
    """Collapse repeated whitespace in string columns.

    Args:
        columns: Iterable of column names to normalize. Duplicates are ignored.
        trim_ends: When ``True``, also ``trim`` the resulting string.
        pattern: Regex pattern to match; defaults to consecutive whitespace.
        replacement: Replacement string for the regex matches.

    Raises:
        TypeError: If ``columns`` is provided as a single ``str``.
        ValueError: If any referenced column is missing.
    """
    if isinstance(columns, str):
        raise TypeError("columns must be an iterable of column names, not a string")

    targets = list(dict.fromkeys(columns))
    if not targets:
        return df

    missing = [name for name in targets if name not in df.columns]
    if missing:
        raise ValueError(f"Cannot normalize missing columns: {missing}")

    result = df
    for name in targets:
        normalized = F.regexp_replace(F.col(name), pattern, replacement)
        if trim_ends:
            normalized = F.trim(normalized)
        result = result.withColumn(name, normalized)
    return result
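
Example — a minimal sketch (local SparkSession assumed); note that columns must be an iterable of names, not a bare string:

from pyspark.sql import SparkSession
from spark_fuse.utils.transformations import normalize_whitespace

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([("  Jane   Doe  ",)], ["name"])

# Collapse runs of whitespace with the default pattern and trim the ends.
cleaned = normalize_whitespace(df, ["name"])
print(cleaned.first()["name"])  # 'Jane Doe'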

split_by_date_formats

split_by_date_formats(df: DataFrame, column: str, formats: Iterable[str], *, handle_errors: str = 'null', default_value: Optional[str] = None, return_unmatched: bool = False, output_column: Optional[str] = None) -> Union[DataFrame, Tuple[DataFrame, DataFrame]]

Split df into per-format partitions with safely parsed date columns.

Parameters:

    column (str, required): Name of the string column containing date representations.
    formats (Iterable[str], required): Iterable of date format strings, evaluated in order.
    handle_errors (str, default 'null'): Strategy for unmatched rows ("null", "strict", "default").
    default_value (Optional[str], default None): Fallback date string when handle_errors="default".
    return_unmatched (bool, default False): When True, also return the unmatched rows DataFrame.
    output_column (Optional[str], default None): Name for the parsed date column; defaults to f"{column}_date".

Returns:

    Union[DataFrame, Tuple[DataFrame, DataFrame]]: The combined DataFrame containing all parsed rows.
    When return_unmatched is True, the unmatched rows DataFrame is also returned as the second element of a tuple.

Raises:

    TypeError: If formats is a string or contains non-string entries.
    ValueError: For missing columns, a duplicate output column, invalid modes, or unmatched rows in strict mode.

Source code in src/spark_fuse/utils/transformations.py
def split_by_date_formats(
    df: DataFrame,
    column: str,
    formats: Iterable[str],
    *,
    handle_errors: str = "null",
    default_value: Optional[str] = None,
    return_unmatched: bool = False,
    output_column: Optional[str] = None,
) -> Union[DataFrame, Tuple[DataFrame, DataFrame]]:
    """Split ``df`` into per-format partitions with safely parsed date columns.

    Args:
        column: Name of the string column containing date representations.
        formats: Iterable of date format strings, evaluated in order.
        handle_errors: Strategy for unmatched rows (``"null"``, ``"strict"``, ``"default"``).
        default_value: Fallback date string when ``handle_errors="default"``.
        return_unmatched: When ``True``, also return the unmatched rows DataFrame.
        output_column: Optional name for the parsed date column; defaults to ``f"{column}_date"``.

    Returns:
        The combined DataFrame containing all parsed rows.

        When ``return_unmatched`` is ``True``, also returns the unmatched rows
        DataFrame as a second element.

    Raises:
        TypeError: If ``formats`` is a string or contains non-string entries.
        ValueError: For missing columns, duplicate output column, invalid modes, or
            unmatched rows when in ``strict`` mode.
    """

    if column not in df.columns:
        raise ValueError(f"Column '{column}' not found in DataFrame")

    parsed_column = output_column or f"{column}_date"
    if parsed_column in df.columns and parsed_column != column:
        raise ValueError(f"Output column '{parsed_column}' already exists")

    if isinstance(formats, str):
        raise TypeError("formats must be an iterable of strings, not a string")

    format_list = list(dict.fromkeys(formats))
    if not format_list:
        raise ValueError("At least one date format must be provided")

    if any(not isinstance(fmt, str) for fmt in format_list):
        raise TypeError("Each format must be a string")

    mode = handle_errors.lower()
    if mode not in _HANDLE_ERROR_MODES:
        raise ValueError(f"Unsupported handle_errors mode '{handle_errors}'")

    if mode == "default" and default_value is None:
        raise ValueError("default_value must be provided when handle_errors='default'")

    parsed_expressions = [F.to_date(F.col(column), fmt) for fmt in format_list]
    if len(parsed_expressions) == 1:
        parsed_expr = parsed_expressions[0]
    else:
        parsed_expr = F.coalesce(*parsed_expressions)

    # Record the index of the first format that parses each value. Build the CASE
    # expression back-to-front because .otherwise() may only be applied once per chain.
    format_idx_expr = F.lit(None)
    for idx, expr in reversed(list(enumerate(parsed_expressions))):
        format_idx_expr = F.when(expr.isNotNull(), F.lit(idx)).otherwise(format_idx_expr)

    format_idx_column = f"__{column}_format_idx__"
    while format_idx_column in df.columns:
        format_idx_column = f"_{format_idx_column}"

    df_with_meta = df.withColumn(parsed_column, parsed_expr).withColumn(
        format_idx_column, format_idx_expr
    )

    partitions: list[DataFrame] = []
    for idx, _ in enumerate(format_list):
        group_df = df_with_meta.filter(F.col(format_idx_column) == idx).drop(format_idx_column)
        partitions.append(group_df)

    unmatched_df = df_with_meta.filter(F.col(format_idx_column).isNull()).drop(format_idx_column)

    if mode == "strict":
        if unmatched_df.limit(1).collect():
            raise ValueError("Unmatched rows detected while handle_errors='strict'")
    elif mode == "default":
        default_df = unmatched_df.withColumn(parsed_column, F.lit(default_value).cast("date"))
        partitions.append(default_df)
    else:
        partitions.append(unmatched_df)

    result_df = partitions[0]
    for part in partitions[1:]:
        result_df = result_df.unionByName(part)

    result: Union[DataFrame, Tuple[DataFrame, DataFrame]] = result_df
    if return_unmatched:
        result = (result_df, unmatched_df)
    return result
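
Example — a minimal sketch (local SparkSession assumed). Formats are Spark date-format strings tried in order; with the default "null" mode, unmatched rows keep a null parsed value and can be returned separately:

from pyspark.sql import SparkSession
from spark_fuse.utils.transformations import split_by_date_formats

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame(
    [("2024-01-31",), ("31/01/2024",), ("not a date",)], ["raw"]
)

# The parsed column defaults to f"{column}_date", i.e. "raw_date" here.
parsed, unmatched = split_by_date_formats(
    df, "raw", ["yyyy-MM-dd", "dd/MM/yyyy"], return_unmatched=True
)
parsed.select("raw", "raw_date").show()
unmatched.select("raw").show()  # the row that matched no format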

map_column_with_llm

map_column_with_llm(df: DataFrame, column: str, target_values: Union[Sequence[str], Mapping[str, Any]], *, model: str = 'gpt-3.5-turbo', dry_run: bool = False, max_retries: int = 3, request_timeout: int = 30, temperature: Optional[float] = 0.0) -> DataFrame

Map column values to target_values via a scalar PySpark UDF.

The transformation applies a regular user-defined function across the column, keeping a per-executor in-memory cache to avoid duplicate LLM calls. Spark accumulators track mapping statistics. When dry_run=True the UDF performs case-insensitive matching only and yields None for unmatched rows without contacting the LLM. When targeting models that require provider-managed sampling behaviour, set temperature=None to omit the temperature parameter from LLM requests.

Parameters:

    df (DataFrame, required): Input DataFrame whose values should be normalized.
    column (str, required): Source column containing the free-form text to map.
    target_values (Union[Sequence[str], Mapping[str, Any]], required): List or mapping defining the set of canonical outputs. When a mapping is provided, its keys are treated as the canonical set.
    model (str, default 'gpt-3.5-turbo'): Chat model (or Azure deployment name) to query.
    dry_run (bool, default False): Skip external calls and simply echo canonical matches (useful for smoke testing and cost estimation).
    max_retries (int, default 3): Retry budget passed to _fetch_llm_mapping.
    request_timeout (int, default 30): Timeout in seconds for each HTTP request.
    temperature (Optional[float], default 0.0): LLM sampling temperature. Use None to skip explicitly setting it (some provider models accept only their default temperature).

Returns:

    DataFrame: A new DataFrame with an additional <column>_mapped string column containing the canonical value, or None when no match is determined.

Raises:

    ValueError: If the source column is missing or target_values is empty.
    TypeError: When target_values contains non-string entries.

Notes:

  • The resulting DataFrame is cached so that logging the accumulator values does not trigger duplicate LLM requests.
  • Provide API credentials via the environment variables documented in _get_llm_api_config before running with dry_run=False.
Source code in src/spark_fuse/utils/transformations.py
def map_column_with_llm(
    df: DataFrame,
    column: str,
    target_values: Union[Sequence[str], Mapping[str, Any]],
    *,
    model: str = "gpt-3.5-turbo",
    dry_run: bool = False,
    max_retries: int = 3,
    request_timeout: int = 30,
    temperature: Optional[float] = 0.0,
) -> DataFrame:
    """Map ``column`` values to ``target_values`` via a scalar PySpark UDF.

    The transformation applies a regular user-defined function across the column, keeping
    a per-executor in-memory cache to avoid duplicate LLM calls. Spark accumulators track
    mapping statistics. When ``dry_run=True`` the UDF performs case-insensitive matching
    only and yields ``None`` for unmatched rows without contacting the LLM. When targeting
    models that require provider-managed sampling behaviour, set ``temperature=None`` to
    omit the ``temperature`` parameter from LLM requests.

    Args:
        df: Input DataFrame whose values should be normalized.
        column: Source column containing the free-form text to map.
        target_values: List or mapping defining the set of canonical outputs. When a
            mapping is provided, its keys are treated as the canonical set.
        model: Chat model (or Azure deployment name) to query.
        dry_run: Skip external calls and simply echo canonical matches (useful for smoke
            testing and cost estimation).
        max_retries: Retry budget passed to :func:`_fetch_llm_mapping`.
        request_timeout: Timeout in seconds for each HTTP request.
        temperature: LLM sampling temperature. Use ``None`` to skip explicitly setting it
            (some provider models accept only their default temperature).

    Returns:
        A new DataFrame with an additional ``<column>_mapped`` string column containing
        the canonical value or ``None`` when no match is determined.

    Raises:
        ValueError: If the source column is missing or ``target_values`` is empty.
        TypeError: When ``target_values`` contains non-string entries.

    Notes:
        - The resulting DataFrame is cached to ensure logging the accumulator values does
          not trigger duplicate LLM requests.
        - Provide API credentials via the environment variables documented in
          :func:`_get_llm_api_config` before running with ``dry_run=False``.
    """

    if column not in df.columns:
        raise ValueError(f"Column '{column}' not found in DataFrame")

    if isinstance(target_values, Mapping):
        targets = list(dict.fromkeys(target_values.keys()))
    else:
        targets = list(dict.fromkeys(target_values))

    if not targets:
        raise ValueError("target_values must contain at least one entry")

    if not all(isinstance(target, str) for target in targets):
        raise TypeError("target_values entries must be strings")

    lookup: Dict[str, str] = {target.lower(): target for target in targets}
    target_list = list(lookup.values())

    api_url: Optional[str] = None
    headers: Dict[str, str] = {}
    use_azure = False

    if not dry_run:
        api_url, headers, use_azure = _get_llm_api_config(model)

    spark = df.sparkSession
    sc = spark.sparkContext
    calls_acc = _create_long_accumulator(sc, f"llm_api_calls_{column}")
    mapped_acc = _create_long_accumulator(sc, f"mapped_entries_{column}")
    unmapped_acc = _create_long_accumulator(sc, f"unmapped_entries_{column}")

    new_col_name = f"{column}_mapped"

    def _make_mapper():
        cache: Dict[str, Optional[str]] = {}

        def _map_value(raw_value: Any) -> Optional[str]:
            if raw_value is None:
                unmapped_acc.add(1)
                return None

            value_str = str(raw_value)
            if value_str.strip() == "":
                unmapped_acc.add(1)
                return None

            if dry_run:
                mapped_value = lookup.get(value_str.lower())
                if mapped_value is None:
                    unmapped_acc.add(1)
                else:
                    mapped_acc.add(1)
                return mapped_value

            if value_str in cache:
                mapped_value = cache[value_str]
            else:
                calls_acc.add(1)
                mapped_candidate = _fetch_llm_mapping(
                    value_str,
                    target_list,
                    api_url=api_url,  # type: ignore[arg-type]
                    headers=headers,
                    use_azure=use_azure,
                    model=model,
                    max_retries=max_retries,
                    request_timeout=request_timeout,
                    temperature=temperature,
                )
                if mapped_candidate is not None:
                    mapped_value = lookup.get(mapped_candidate.lower(), mapped_candidate)
                else:
                    mapped_value = None
                cache[value_str] = mapped_value

            if mapped_value is None:
                unmapped_acc.add(1)
            else:
                mapped_acc.add(1)
            return mapped_value

        return _map_value

    mapper_udf = F.udf(_make_mapper(), StringType())

    mapped_df = df.withColumn(new_col_name, mapper_udf(F.col(column))).cache()
    mapped_df.count()

    mapped_count = mapped_acc.value
    unmapped_count = unmapped_acc.value
    logger.info(
        "Mapping stats for column '%s': Mapped %s, Unmapped %s, API calls made %s.",
        column,
        mapped_count,
        unmapped_count,
        calls_acc.value,
    )

    return mapped_df
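
Example — a minimal dry-run sketch (local SparkSession assumed). With dry_run=True only case-insensitive matching against the canonical values is performed, so no API credentials or network access are required; the target values here are illustrative:

from pyspark.sql import SparkSession
from spark_fuse.utils.transformations import map_column_with_llm

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([("usa",), ("United States of America",)], ["country"])

# Adds a "country_mapped" column; in dry-run mode "usa" matches "USA"
# case-insensitively and the unmatched row maps to null.
mapped = map_column_with_llm(df, "country", ["USA", "Canada"], dry_run=True)
mapped.select("country", "country_mapped").show()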