Skip to content

Transformations

flatten_dataframe(df, separator=':', replace_char='_', sanitized_columns=False)

Flattens the complex columns in the DataFrame.

Parameters:

Name Type Description Default
df DataFrame

The input PySpark DataFrame.

required
separator str

The separator to use in the resulting flattened column names, defaults to ":".

':'
replace_char str

The character to replace special characters with in column names, defaults to "_".

'_'
sanitized_columns bool

Whether to sanitize column names, defaults to False.

False

Returns:

Type Description
DataFrame

Note: This function assumes the input DataFrame has a consistent schema across all rows. If you have files with different schemas, process each separately instead.

Example usage:

    >>> data = [
    ...     (1, ("Alice", 25), {"A": 100, "B": 200}, ["apple", "banana"], {"key": {"nested_key": 10}}, {"A#": 1000, "B@": 2000}),
    ...     (2, ("Bob", 30), {"A": 150, "B": 250}, ["orange", "grape"], {"key": {"nested_key": 20}}, {"A#": 1500, "B@": 2500}),
    ... ]
    >>> df = spark.createDataFrame(data)
    >>> flattened_df = flatten_dataframe(df)
    >>> flattened_df.show()
    >>> # replace_char only takes effect when sanitized_columns=True
    >>> flattened_df_with_hyphen = flatten_dataframe(df, replace_char="-", sanitized_columns=True)
    >>> flattened_df_with_hyphen.show()

The DataFrame with all complex data types flattened.

Source code in quinn/transformations.py
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
def flatten_dataframe(
    df: DataFrame,
    separator: str = ":",
    replace_char: str = "_",
    sanitized_columns: bool = False,
) -> DataFrame:
    """Flattens the complex columns in the DataFrame.

    Struct, array and map columns are expanded repeatedly until no complex
    column remains:

    - struct fields become ``<parent><separator><field>`` columns;
    - arrays are exploded into one row per element with ``explode_outer``;
    - maps become one ``<parent><separator><key>`` column per distinct key.

    :param df: The input PySpark DataFrame.
    :type df: DataFrame
    :param separator: The separator to use in the resulting flattened column names, defaults to ":".
    :type separator: str, optional
    :param replace_char: The character to replace special characters with in column names, defaults to "_".
        Only used when ``sanitized_columns`` is True.
    :type replace_char: str, optional
    :param sanitized_columns: Whether to sanitize column names, defaults to False.
    :type sanitized_columns: bool, optional
    :return: The DataFrame with all complex data types flattened.
    :rtype: DataFrame

    .. note:: This function assumes the input DataFrame has a consistent schema across all rows. If you have files with
        different schemas, process each separately instead.

    .. example:: Example usage:

        >>> data = [
        ...     (
        ...         1,
        ...         ("Alice", 25),
        ...         {"A": 100, "B": 200},
        ...         ["apple", "banana"],
        ...         {"key": {"nested_key": 10}},
        ...         {"A#": 1000, "B@": 2000},
        ...     ),
        ...     (
        ...         2,
        ...         ("Bob", 30),
        ...         {"A": 150, "B": 250},
        ...         ["orange", "grape"],
        ...         {"key": {"nested_key": 20}},
        ...         {"A#": 1500, "B@": 2500},
        ...     ),
        ... ]
        >>> df = spark.createDataFrame(data)
        >>> flattened_df = flatten_dataframe(df)
        >>> flattened_df.show()
        >>> flattened_df_with_hyphen = flatten_dataframe(df, replace_char="-", sanitized_columns=True)
        >>> flattened_df_with_hyphen.show()
    """

    def sanitize_column_name(name: str, rc: str = "_") -> str:
        """Sanitizes column names by replacing special characters with the specified character.

        :param name: The original column name.
        :type name: str
        :param rc: The character to replace special characters with, defaults to '_'.
        :type rc: str, optional
        :return: The sanitized column name.
        :rtype: str
        """
        # Anything outside ASCII letters, digits and underscore is replaced.
        return re.sub(r"[^a-zA-Z0-9_]", rc, name)

    def explode_array(df: DataFrame, col_name: str) -> DataFrame:
        """Explodes the specified ArrayType column in the input DataFrame and returns a new DataFrame with the exploded column.

        The exploded column keeps the array's name and is placed after the
        remaining columns.

        :param df: The input PySpark DataFrame.
        :type df: DataFrame
        :param col_name: The column name of the ArrayType to be exploded.
        :type col_name: str
        :return: The DataFrame with the exploded ArrayType column.
        :rtype: DataFrame
        """
        # Build the projection explicitly rather than ``select("*", ...).drop(col_name)``:
        # the exploded column shares the array's name, and dropping by name removes
        # every column with that name, which would throw away the exploded values too.
        other_cols = [F.col(f"`{name}`") for name in df.columns if name != col_name]
        return df.select(
            *other_cols,
            F.explode_outer(F.col(f"`{col_name}`")).alias(col_name),
        )

    fields = complex_fields(df.schema)

    # Each pass removes one complex column (possibly exposing new ones, e.g. a
    # struct inside an array), so the schema is re-inspected after every step.
    while fields:
        col_name = next(iter(fields.keys()))

        if isinstance(fields[col_name], StructType):
            df = flatten_struct(df, col_name, separator)  # noqa: PD901

        elif isinstance(fields[col_name], ArrayType):
            df = explode_array(df, col_name)  # noqa: PD901

        elif isinstance(fields[col_name], MapType):
            df = flatten_map(df, col_name, separator)  # noqa: PD901

        fields = complex_fields(df.schema)

    # Sanitize column names with the specified replace_char
    if sanitized_columns:
        new_names = [sanitize_column_name(name, replace_char) for name in df.columns]
        df = df.toDF(*new_names)  # noqa: PD901

    return df

flatten_map(df, col_name, separator=':')

Flattens the specified MapType column in the input DataFrame and returns a new DataFrame with the flattened columns.

Parameters:

Name Type Description Default
df DataFrame

The input PySpark DataFrame.

required
col_name str

The column name of the MapType to be flattened.

required
separator str

The separator to use in the resulting flattened column names, defaults to ":".

':'

Returns:

Type Description
DataFrame

The DataFrame with the flattened MapType column.

Source code in quinn/transformations.py
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
def flatten_map(df: DataFrame, col_name: str, separator: str = ":") -> DataFrame:
    """Flattens the specified MapType column in the input DataFrame and returns a new DataFrame with the flattened columns.

    The distinct keys of the map column are collected to the driver (this runs
    a Spark job). One column named ``<col_name><separator><key>`` is then
    created per key, appended after the remaining columns. The original map
    column is removed.

    :param df: The input PySpark DataFrame.
    :type df: DataFrame
    :param col_name: The column name of the MapType to be flattened.
    :type col_name: str
    :param separator: The separator to use in the resulting flattened column names, defaults to ":".
    :type separator: str, optional
    :return: The DataFrame with the flattened MapType column.
    :rtype: DataFrame
    """
    map_col = F.col(f"`{col_name}`")
    # Use ``explode`` rather than ``explode_outer``. The outer variant emits a null
    # key for rows whose map is null or empty, and a null key cannot be turned
    # into a column name (string concatenation with None raised TypeError).
    keys_df = df.select(F.explode(F.map_keys(map_col))).distinct()
    keys = [row[0] for row in keys_df.collect()]
    # The f-string stringifies the key, so non-string map keys (e.g. integers)
    # also produce valid column names.
    key_cols = [map_col.getItem(k).alias(f"{col_name}{separator}{k}") for k in keys]
    other_cols = [F.col(f"`{col}`") for col in df.columns if col != col_name]
    return df.select(*other_cols, *key_cols)

flatten_struct(df, col_name, separator=':')

Flattens the specified StructType column in the input DataFrame and returns a new DataFrame with the flattened columns.

Parameters:

Name Type Description Default
df DataFrame

The input PySpark DataFrame.

required
col_name str

The column name of the StructType to be flattened.

required
separator str

The separator to use in the resulting flattened column names, defaults to ':'.

':'

Returns:

Type Description
DataFrame

The DataFrame with the flattened StructType column.

Source code in quinn/transformations.py
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
def flatten_struct(df: DataFrame, col_name: str, separator: str = ":") -> DataFrame:
    """Flattens the specified StructType column in the input DataFrame and returns a new DataFrame with the flattened columns.

    Each field ``f`` of the struct becomes a top-level column named
    ``<col_name><separator><f>``. The new columns are appended after the
    existing ones, and the original struct column is dropped.

    :param df: The input PySpark DataFrame.
    :type df: DataFrame
    :param col_name: The column name of the StructType to be flattened.
    :type col_name: str
    :param separator: The separator to use in the resulting flattened column names, defaults to ':'.
    :type separator: str, optional
    :return: The DataFrame with the flattened StructType column.
    :rtype: DataFrame
    """
    struct_type = complex_fields(df.schema)[col_name]
    expanded = [
        F.col(f"`{col_name}`.`{field.name}`").alias(f"{col_name}{separator}{field.name}")
        for field in struct_type.fields
    ]
    return df.select("*", *expanded).drop(F.col(f"`{col_name}`"))

snake_case_col_names(df)

Function takes a DataFrame instance and returns a new DataFrame with all column names converted to snake case (e.g. col_name_1). It uses the to_snake_case function in conjunction with the with_columns_renamed function to achieve this.

Parameters:

Name Type Description Default
df DataFrame

A DataFrame instance to process

required

Returns:

Type Description
DataFrame

A DataFrame instance with column names converted to snake case

Source code in quinn/transformations.py
65
66
67
68
69
70
71
72
73
74
75
def snake_case_col_names(df: DataFrame) -> DataFrame:
    """Return ``df`` with every column name converted to snake case (e.g. ``col_name_1``).

    Each name is passed through ``to_snake_case`` by the renamer that
    ``with_columns_renamed`` builds.

    :param df: A ``DataFrame`` instance to process
    :type df: ``DataFrame``
    :return: A ``DataFrame`` instance with column names converted to snake case
    :rtype: ``DataFrame``
    """
    rename_all = with_columns_renamed(to_snake_case)
    return rename_all(df)

sort_columns(df, sort_order, sort_nested=False)

This function sorts the columns of a given DataFrame based on a given sort order. The sort_order parameter can either be asc or desc, which correspond to ascending and descending order, respectively. If any other value is provided for the sort_order parameter, a ValueError will be raised.

Parameters:

Name Type Description Default
df DataFrame

A DataFrame

required
sort_order str

The order in which to sort the columns in the DataFrame

required
sort_nested bool

Whether to sort nested structs or not. Defaults to False.

False

Returns:

Type Description
pyspark.sql.DataFrame

A DataFrame with the columns sorted in the chosen order

Source code in quinn/transformations.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
def sort_columns(  # noqa: C901,PLR0915
    df: DataFrame,
    sort_order: str,
    sort_nested: bool = False,
) -> DataFrame:
    """This function sorts the columns of a given DataFrame based on a given sort
    order. The ``sort_order`` parameter can either be ``asc`` or ``desc``, which correspond to
    ascending and descending order, respectively. If any other value is provided for
    the ``sort_order`` parameter, a ``ValueError`` will be raised.

    When ``sort_nested`` is True, the fields of struct columns and of structs held
    directly in arrays are sorted as well. Arrays of non-struct elements and all
    other column types are passed through unchanged.

    :param df: A DataFrame
    :type df: pyspark.sql.DataFrame
    :param sort_order: The order in which to sort the columns in the DataFrame
    :type sort_order: str
    :param sort_nested: Whether to sort nested structs or not. Defaults to false.
    :type sort_nested: bool
    :return: A DataFrame with the columns sorted in the chosen order
    :rtype: pyspark.sql.DataFrame
    :raises ValueError: If ``sort_order`` is not ``"asc"`` or ``"desc"``.
    """

    def sort_nested_cols(
        schema: StructType, is_reversed: bool, base_field: str = "",
    ) -> list[str]:
        # Build one SQL select expression per field of ``schema`` (in sorted
        # order), recursing into structs and arrays of structs.
        # ``base_field`` is the dotted path of the enclosing struct ("" at top level).
        # https://stackoverflow.com/questions/57821538/how-to-sort-columns-of-nested-structs-alphabetically-in-pyspark
        # Credits: @pault for logic

        def parse_fields(
            fields_to_sort: list,
            parent_struct: StructType,
            is_reversed: bool,
        ) -> list:
            # ``fields_to_sort`` are StructField JSON dicts (from ``jsonValue()``).
            sorted_fields: list = sorted(
                fields_to_sort,
                key=lambda x: x["name"],
                reverse=is_reversed,
            )

            results = []
            for field in sorted_fields:
                new_struct = StructType([StructField.fromJson(field)])
                new_base_field = parent_struct.name
                if base_field:
                    new_base_field = base_field + "." + new_base_field

                results.extend(
                    sort_nested_cols(
                        new_struct, is_reversed, base_field=new_base_field,
                    ),
                )
            return results

        select_cols = []
        for parent_struct in sorted(schema, key=lambda x: x.name, reverse=is_reversed):
            field_type = parent_struct.dataType
            # Full dotted reference to this field from the top-level row.
            qualified_name = (
                f"{base_field}.{parent_struct.name}" if base_field else parent_struct.name
            )

            if isinstance(field_type, ArrayType) and isinstance(
                field_type.elementType, StructType,
            ):
                array_fields = parent_struct.jsonValue()["type"]["elementType"]["fields"]
                array_elements = parse_fields(array_fields, parent_struct, is_reversed)
                # NOTE(review): only the last dotted segment is kept, so this assumes
                # the element struct's fields are leaves. A struct nested inside an
                # array element yields a "struct(...) AS name" string, which this
                # split does not handle.
                element_names = [i.split(".")[-1] for i in array_elements]
                array_elements_formatted = [f"x.{i} as {i}" for i in element_names]

                # create a string representation of the sorted array
                # ex: transform(phone_numbers, x -> struct(x.number as number, x.type as type)) AS phone_numbers
                result = (
                    f"transform({qualified_name}, x -> struct("
                    f"{', '.join(array_elements_formatted)})) AS {parent_struct.name}"
                )

            elif isinstance(field_type, StructType):
                field_list = parent_struct.jsonValue()["type"]["fields"]
                sub_fields = parse_fields(field_list, parent_struct, is_reversed)

                # create a string representation of the sorted struct
                # ex: struct(address.zip.first5, address.zip.last4) AS zip
                result = f"struct({', '.join(sub_fields)}) AS {parent_struct.name}"

            else:
                # Leaf columns and arrays of non-struct elements have nothing to
                # sort inside, so they are referenced as-is.
                result = qualified_name
            select_cols.append(result)

        return select_cols

    def get_original_nullability(
        field: StructField, result_dict: dict, path: tuple = (),
    ) -> None:
        # Record nullability keyed by the field's full path, so same-named fields
        # at different nesting levels do not overwrite each other.
        key = (*path, field.name)
        result_dict[key] = getattr(field, "nullable", True)

        data_type = field.dataType
        if isinstance(data_type, ArrayType):
            # Field names are strings, so ``None`` safely marks "the array's elements".
            result_dict[(*key, None)] = data_type.containsNull
            data_type = data_type.elementType

        # Only struct types (directly or as array elements) have children.
        if isinstance(data_type, StructType):
            for child in data_type.fields:
                get_original_nullability(child, result_dict, key)

    def fix_nullability(field: StructField, result_dict: dict, path: tuple = ()) -> None:
        # Mirror of get_original_nullability: restore the recorded flags in place.
        key = (*path, field.name)
        field.nullable = result_dict.get(key, field.nullable)

        data_type = field.dataType
        if isinstance(data_type, ArrayType):
            # save the containsNull property of the ArrayType
            data_type.containsNull = result_dict.get((*key, None), data_type.containsNull)
            data_type = data_type.elementType

        if isinstance(data_type, StructType):
            for child in data_type.fields:
                fix_nullability(child, result_dict, key)

    if sort_order not in ["asc", "desc"]:
        msg = f"['asc', 'desc'] are the only valid sort orders and you entered a sort order of '{sort_order}'"
        raise ValueError(
            msg,
        )

    is_reversed: bool = sort_order == "desc"
    top_level_sorted_df = df.select(*sorted(df.columns, reverse=is_reversed))
    if not sort_nested:
        return top_level_sorted_df

    is_nested: bool = any(
        isinstance(i.dataType, (StructType, ArrayType))
        for i in top_level_sorted_df.schema
    )

    if not is_nested:
        return top_level_sorted_df

    fully_sorted_schema = sort_nested_cols(top_level_sorted_df.schema, is_reversed)
    output = df.selectExpr(*fully_sorted_schema)

    result_dict: dict = {}
    for field in df.schema:
        get_original_nullability(field, result_dict)

    # NOTE(review): this mutates the StructField objects returned by
    # ``output.schema`` in place. Whether that is visible on the returned
    # DataFrame depends on PySpark caching the schema object; it does not change
    # the underlying query plan. Confirm against the PySpark version in use.
    for field in output.schema:
        fix_nullability(field, result_dict)

    return output

to_snake_case(s)

Takes a string and converts it to snake case format.

Parameters:

Name Type Description Default
s str

The string to be converted.

required

Returns:

Type Description
str

The string in snake case format.

Source code in quinn/transformations.py
78
79
80
81
82
83
84
85
86
def to_snake_case(s: str) -> str:
    """Convert ``s`` to snake case by lower-casing it and turning each space into ``_``.

    Every space is replaced individually, so runs of spaces become runs of
    underscores. No other characters are touched.

    :param s: The string to be converted.
    :type s: str
    :return: The string in snake case format.
    :rtype: str
    """
    words = s.lower().split(" ")
    return "_".join(words)

with_columns_renamed(fun)

Function designed to rename the columns of a Spark DataFrame.

It takes a Callable[[str], str] object as an argument (fun) and returns a Callable[[DataFrame], DataFrame] object.

When the returned function is called on a DataFrame, it applies fun() to each column name and returns a new DataFrame whose columns carry the resulting names.

Parameters:

Name Type Description Default
fun Callable[[str], str]

Renaming function

required

Returns:

Type Description
Callable[[DataFrame], DataFrame]

Function which takes DataFrame as parameter.

Source code in quinn/transformations.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def with_columns_renamed(fun: Callable[[str], str]) -> Callable[[DataFrame], DataFrame]:
    """Build a transformation that renames every column of a `Spark DataFrame`.

    The returned callable takes a `DataFrame`, calls ``fun`` on each column name
    in column order, and returns a new `DataFrame` (via ``select``) whose
    columns carry the names produced by ``fun``.

    :param fun: Renaming function
    :returns: Function which takes DataFrame as parameter.
    """

    def _(df: DataFrame) -> DataFrame:
        renamed = []
        for original in df.columns:
            # The existing name is backtick-quoted before being referenced.
            renamed.append(F.col(f"`{original}`").alias(fun(original)))
        return df.select(*renamed)

    return _

with_some_columns_renamed(fun, change_col_name)

Function that takes a renaming function (Callable[[str], str]) and a predicate (Callable[[str], bool]) and returns a Callable[[DataFrame], DataFrame], which in turn takes a DataFrame and returns a DataFrame with the columns selected by the predicate renamed.

Parameters:

Name Type Description Default
fun Callable[[str], str]

A function that takes a column name as a string and returns a new name as a string.

required
change_col_name Callable[[str], bool]

A function that takes a column name as a string and returns a boolean.

required

Returns:

Type Description
Callable[[DataFrame], DataFrame]

A Callable[[DataFrame], DataFrame], which takes a DataFrame and returns a DataFrame with some of its columns renamed.

Source code in quinn/transformations.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def with_some_columns_renamed(
    fun: Callable[[str], str],
    change_col_name: Callable[[str], bool],
) -> Callable[[DataFrame], DataFrame]:
    """Function that takes a renaming `Callable[[str], str]` and a predicate `Callable[[str], bool]` and returns a `Callable[[DataFrame], DataFrame]`.

    The returned callable takes a `DataFrame` and returns a new `DataFrame`
    (via ``select``, column order preserved). A column is renamed with ``fun``
    when ``change_col_name`` is truthy for its name; other columns keep their
    names.

    :param fun: A function that takes a column name as a string and returns a
    new name as a string.
    :type fun: `Callable[[str], str]`
    :param change_col_name: A function that takes a column name as a string and
    returns a boolean.
    :type change_col_name: `Callable[[str], bool]`
    :return: A `Callable[[DataFrame], DataFrame]`, which takes a
    `DataFrame` and returns a `DataFrame` with some of its columns renamed.
    :rtype: `Callable[[DataFrame], DataFrame]`
    """

    def _(df: DataFrame) -> DataFrame:
        cols = [
            F.col(f"`{col_name}`").alias(fun(col_name))
            if change_col_name(col_name)
            else F.col(f"`{col_name}`")
            for col_name in df.columns
        ]
        return df.select(*cols)

    return _