Skip to content

Transformation Utilities

Pure DataFrame column transformations.

LLM-powered transformations (`with_langchain_embeddings`, `map_column_with_llm`) live in the `spark_fuse.utils.llm` module.

rename_columns

rename_columns(df: DataFrame, mapping: Mapping[str, str]) -> DataFrame

Rename columns according to mapping while preserving column order.

Raises:

Type Description
ValueError

If any source column is missing or the resulting columns collide.

Source code in src/spark_fuse/utils/transformations.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
def rename_columns(df: DataFrame, mapping: Mapping[str, str]) -> DataFrame:
    """Rename columns according to ``mapping`` while preserving column order.

    Args:
        mapping: Mapping of existing column name to its replacement name.

    Raises:
        ValueError: If any source column is missing or the resulting columns collide.
    """
    if not mapping:
        return df

    absent = [source for source in mapping if source not in df.columns]
    if absent:
        raise ValueError(f"Cannot rename missing columns: {absent}")

    # Detect collisions before touching the plan: two sources mapped to the
    # same target, or a target clashing with an untouched column.
    targets = [mapping.get(source, source) for source in df.columns]
    if len(set(targets)) != len(targets):
        raise ValueError("Renaming results in duplicate column names")

    def _as_column(source: str):
        # Alias only when the name actually changes; untouched columns pass through.
        target = mapping.get(source, source)
        return F.col(source).alias(target) if target != source else F.col(source)

    return df.select(*[_as_column(source) for source in df.columns])

with_constants

with_constants(df: DataFrame, constants: Mapping[str, Any], *, overwrite: bool = False) -> DataFrame

Add literal-valued columns using constants.

Parameters:

Name Type Description Default
constants Mapping[str, Any]

Mapping of column name to literal value.

required
overwrite bool

Replace existing columns when True (default False).

False

Raises:

Type Description
ValueError

If attempting to add an existing column without overwrite.

Source code in src/spark_fuse/utils/transformations.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def with_constants(
    df: DataFrame,
    constants: Mapping[str, Any],
    *,
    overwrite: bool = False,
) -> DataFrame:
    """Add literal-valued columns using ``constants``.

    Args:
        constants: Mapping of column name to literal value.
        overwrite: Replace existing columns when ``True`` (default ``False``).

    Raises:
        ValueError: If attempting to add an existing column without ``overwrite``.
    """
    if not constants:
        return df

    if not overwrite:
        duplicates = [name for name in constants if name in df.columns]
        if duplicates:
            raise ValueError(f"Columns already exist: {duplicates}")

    # Build a single projection instead of chaining one `withColumn` per entry:
    # each `withColumn` call adds another projection to the logical plan, which
    # gets expensive for wide mappings. Semantics mirror `withColumn`: replaced
    # columns keep their original position, brand-new columns are appended in
    # the mapping's iteration order.
    existing = set(df.columns)
    replaced = [
        F.lit(constants[name]).alias(name) if name in constants else F.col(name)
        for name in df.columns
    ]
    appended = [
        F.lit(value).alias(name)
        for name, value in constants.items()
        if name not in existing
    ]
    return df.select(*replaced, *appended)

cast_columns

cast_columns(df: DataFrame, type_mapping: TypeMapping) -> DataFrame

Cast columns to new Spark SQL types.

The type_mapping values may be str or DataType instances.

Raises:

Type Description
ValueError

If any referenced column is missing.

Source code in src/spark_fuse/utils/transformations.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
def cast_columns(df: DataFrame, type_mapping: TypeMapping) -> DataFrame:
    """Cast columns to new Spark SQL types.

    The ``type_mapping`` values may be ``str`` or ``DataType`` instances.

    Raises:
        ValueError: If any referenced column is missing.
    """
    if not type_mapping:
        return df

    absent = [column for column in type_mapping if column not in df.columns]
    if absent:
        raise ValueError(f"Cannot cast missing columns: {absent}")

    # One projection covering every column: mapped ones are cast (alias keeps
    # the original name), the rest pass through untouched.
    projection = [
        F.col(column).cast(type_mapping[column]).alias(column)
        if column in type_mapping
        else F.col(column)
        for column in df.columns
    ]
    return df.select(*projection)

normalize_whitespace

normalize_whitespace(df: DataFrame, columns: Iterable[str], *, trim_ends: bool = True, pattern: str = _DEFAULT_REGEX, replacement: str = ' ') -> DataFrame

Collapse repeated whitespace in string columns.

Parameters:

Name Type Description Default
columns Iterable[str]

Iterable of column names to normalize. Duplicates are ignored.

required
trim_ends bool

When True, also trim the resulting string.

True
pattern str

Regex pattern to match; defaults to consecutive whitespace.

_DEFAULT_REGEX
replacement str

Replacement string for the regex matches.

' '

Raises:

Type Description
TypeError

If columns is provided as a single str.

ValueError

If any referenced column is missing.

Source code in src/spark_fuse/utils/transformations.py
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
def normalize_whitespace(
    df: DataFrame,
    columns: Iterable[str],
    *,
    trim_ends: bool = True,
    pattern: str = _DEFAULT_REGEX,
    replacement: str = " ",
) -> DataFrame:
    """Collapse repeated whitespace in string columns.

    Args:
        columns: Iterable of column names to normalize. Duplicates are ignored.
        trim_ends: When ``True``, also ``trim`` the resulting string.
        pattern: Regex pattern to match; defaults to consecutive whitespace.
        replacement: Replacement string for the regex matches.

    Raises:
        TypeError: If ``columns`` is provided as a single ``str``.
        ValueError: If any referenced column is missing.
    """
    # A plain string would silently iterate per-character; reject it explicitly.
    if isinstance(columns, str):
        raise TypeError("columns must be an iterable of column names, not a string")

    # dict.fromkeys dedupes while keeping first-seen order.
    targets = list(dict.fromkeys(columns))
    if not targets:
        return df

    missing = [name for name in targets if name not in df.columns]
    if missing:
        raise ValueError(f"Cannot normalize missing columns: {missing}")

    def _normalized(name: str):
        # One regexp_replace pass, optionally trimmed at both ends.
        expr = F.regexp_replace(F.col(name), pattern, replacement)
        return F.trim(expr) if trim_ends else expr

    # Single projection instead of one `withColumn` per target: avoids stacking
    # a projection layer per column while keeping every column in its original
    # position (matching the previous `withColumn` behavior).
    target_set = set(targets)
    return df.select(
        *[
            _normalized(name).alias(name) if name in target_set else F.col(name)
            for name in df.columns
        ]
    )

split_by_date_formats

split_by_date_formats(df: DataFrame, column: str, formats: Iterable[str], *, handle_errors: str = 'null', default_value: Optional[str] = None, return_unmatched: bool = False, output_column: Optional[str] = None) -> Union[DataFrame, Tuple[DataFrame, DataFrame]]

Split df into per-format partitions with safely parsed date columns.

Parameters:

Name Type Description Default
column str

Name of the string column containing date representations.

required
formats Iterable[str]

Iterable of date format strings, evaluated in order.

required
handle_errors str

Strategy for unmatched rows ("null", "strict", "default").

'null'
default_value Optional[str]

Fallback date string when handle_errors="default".

None
return_unmatched bool

When True, also return the unmatched rows DataFrame.

False
output_column Optional[str]

Optional name for the parsed date column; defaults to f"{column}_date".

None

Returns:

Type Description
Union[DataFrame, Tuple[DataFrame, DataFrame]]

The combined DataFrame containing all parsed rows.

Union[DataFrame, Tuple[DataFrame, DataFrame]]

When return_unmatched is True, also returns the unmatched rows

Union[DataFrame, Tuple[DataFrame, DataFrame]]

DataFrame as a second element.

Raises:

Type Description
TypeError

If formats is a string or contains non-string entries.

ValueError

For missing columns, duplicate output column, invalid modes, or unmatched rows when in strict mode.

Source code in src/spark_fuse/utils/transformations.py
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
def split_by_date_formats(
    df: DataFrame,
    column: str,
    formats: Iterable[str],
    *,
    handle_errors: str = "null",
    default_value: Optional[str] = None,
    return_unmatched: bool = False,
    output_column: Optional[str] = None,
) -> Union[DataFrame, Tuple[DataFrame, DataFrame]]:
    """Split ``df`` into per-format partitions with safely parsed date columns.

    Args:
        column: Name of the string column containing date representations.
        formats: Iterable of date format strings, evaluated in order.
        handle_errors: Strategy for unmatched rows (``"null"``, ``"strict"``, ``"default"``).
        default_value: Fallback date string when ``handle_errors="default"``.
        return_unmatched: When ``True``, also return the unmatched rows DataFrame.
        output_column: Optional name for the parsed date column; defaults to ``f"{column}_date"``.

    Returns:
        The combined DataFrame containing all parsed rows.

        When ``return_unmatched`` is ``True``, also returns the unmatched rows
        DataFrame as a second element.

    Raises:
        TypeError: If ``formats`` is a string or contains non-string entries.
        ValueError: For missing columns, duplicate output column, invalid modes, or
            unmatched rows when in ``strict`` mode.
    """

    # --- Input validation ---------------------------------------------------
    if column not in df.columns:
        raise ValueError(f"Column '{column}' not found in DataFrame")

    parsed_column = output_column or f"{column}_date"
    # Overwriting the source column itself is allowed; any other collision is not.
    if parsed_column in df.columns and parsed_column != column:
        raise ValueError(f"Output column '{parsed_column}' already exists")

    # A bare string is iterable and would be treated as per-character "formats".
    if isinstance(formats, str):
        raise TypeError("formats must be an iterable of strings, not a string")

    # dict.fromkeys dedupes while preserving the caller's priority order.
    format_list = list(dict.fromkeys(formats))
    if not format_list:
        raise ValueError("At least one date format must be provided")

    if any(not isinstance(fmt, str) for fmt in format_list):
        raise TypeError("Each format must be a string")

    mode = handle_errors.lower()
    if mode not in _HANDLE_ERROR_MODES:
        raise ValueError(f"Unsupported handle_errors mode '{handle_errors}'")

    if mode == "default" and default_value is None:
        raise ValueError("default_value must be provided when handle_errors='default'")

    # --- Parse expressions --------------------------------------------------
    def _parsed_date_expr(fmt: str):
        """Return a date column that tolerates parse errors when possible."""

        if hasattr(F, "try_to_timestamp"):
            # `try_to_timestamp` expects the format as a column/literal, not a Python string.
            return F.to_date(F.try_to_timestamp(F.col(column), F.lit(fmt)))
        # NOTE(review): without try_to_timestamp this fallback may raise at
        # runtime under ANSI mode instead of yielding NULL — confirm target
        # Spark versions.
        return F.to_date(F.col(column), fmt)

    parsed_expressions = [_parsed_date_expr(fmt) for fmt in format_list]
    if len(parsed_expressions) == 1:
        parsed_expr = parsed_expressions[0]
    else:
        # coalesce picks the first format (in priority order) that parses.
        parsed_expr = F.coalesce(*parsed_expressions)

    # Build a CASE chain recording WHICH format matched (lowest index wins).
    # Iterating in reverse nests later formats inside the `otherwise` branches,
    # so earlier formats take precedence; NULL means no format matched.
    format_idx_expr = F.lit(None)
    for idx, expr in reversed(list(enumerate(parsed_expressions))):
        format_idx_expr = F.when(expr.isNotNull(), F.lit(idx)).otherwise(format_idx_expr)

    # Pick a temporary meta-column name that cannot collide with existing columns.
    format_idx_column = f"__{column}_format_idx__"
    while format_idx_column in df.columns:
        format_idx_column = f"_{format_idx_column}"

    df_with_meta = df.withColumn(parsed_column, parsed_expr).withColumn(
        format_idx_column, format_idx_expr
    )

    # --- Partition by matching format --------------------------------------
    partitions: list[DataFrame] = []
    for idx, _ in enumerate(format_list):
        group_df = df_with_meta.filter(F.col(format_idx_column) == idx).drop(format_idx_column)
        partitions.append(group_df)

    unmatched_df = df_with_meta.filter(F.col(format_idx_column).isNull()).drop(format_idx_column)

    # --- Apply the unmatched-row strategy -----------------------------------
    if mode == "strict":
        # limit(1) keeps the existence check cheap; any surviving row is an error.
        if unmatched_df.limit(1).collect():
            raise ValueError("Unmatched rows detected while handle_errors='strict'")
    elif mode == "default":
        # Overwrite the (NULL) parsed column with the caller-supplied fallback.
        default_df = unmatched_df.withColumn(parsed_column, F.lit(default_value).cast("date"))
        partitions.append(default_df)
    else:
        # "null" mode: keep unmatched rows with a NULL parsed column.
        partitions.append(unmatched_df)

    # unionByName is safe here: every partition shares the augmented schema.
    result_df = partitions[0]
    for part in partitions[1:]:
        result_df = result_df.unionByName(part)

    result: Union[DataFrame, Tuple[DataFrame, DataFrame]] = result_df
    if return_unmatched:
        result = (result_df, unmatched_df)
    return result