
Schema helpers

complex_fields(schema)

Returns a dictionary of complex field names and their data types from the input DataFrame's schema.

Parameters:

Name     Type        Description                              Default
schema   StructType  The input PySpark DataFrame's schema.    required

Returns:

Type               Description
Dict[str, object]  A dictionary with complex field names as keys and their respective data types as values.

Source code in quinn/schema_helpers.py
def complex_fields(schema: T.StructType) -> dict[str, object]:
    """Returns a dictionary of complex field names and their data types from the input DataFrame's schema.

    :param schema: The input PySpark DataFrame's schema.
    :type schema: pyspark.sql.types.StructType
    :return: A dictionary with complex field names as keys and their respective data types as values.
    :rtype: Dict[str, object]
    """
    return {
        field.name: field.dataType
        for field in schema.fields
        if isinstance(field.dataType, (T.ArrayType, T.StructType, T.MapType))
    }
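
Example usage (a sketch; the SparkSession and DataFrame below are illustrative, not part of the library):

from pyspark.sql import SparkSession
from pyspark.sql import types as T

from quinn.schema_helpers import complex_fields

spark = SparkSession.builder.getOrCreate()

# Illustrative schema: one simple column plus two complex (array/struct) columns.
schema = T.StructType([
    T.StructField("id", T.StringType(), True),
    T.StructField("tags", T.ArrayType(T.StringType()), True),
    T.StructField(
        "address",
        T.StructType([T.StructField("city", T.StringType(), True)]),
        True,
    ),
])
df = spark.createDataFrame([], schema)

# Only the array and struct columns appear in the result; "id" is filtered out.
print(complex_fields(df.schema))
# e.g. {'tags': ArrayType(StringType(), True), 'address': StructType(...)}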

print_schema_as_code(dtype)

Represent DataType (including StructType) as valid Python code.

Parameters:

Name    Type        Description                             Default
dtype   T.DataType  The input DataType or schema object.    required

Returns:

Type  Description
str   Valid Python code that generates the same schema.

Source code in quinn/schema_helpers.py
def print_schema_as_code(dtype: T.DataType) -> str:
    """Represent DataType (including StructType) as valid Python code.

    :param dtype: The input DataType or Schema object
    :type dtype: pyspark.sql.types.DataType
    :return: Valid Python code that generates the same schema.
    :rtype: str
    """
    res = []
    if isinstance(dtype, T.StructType):
        res.append("StructType(\n\tfields=[")
        for field in dtype.fields:
            for line in _repr_column(field).split("\n"):
                res.append("\n\t\t")
                res.append(line)
            res.append(",")
        res.append("\n\t]\n)")

    elif isinstance(dtype, T.ArrayType):
        res.append("ArrayType(")
        res.append(print_schema_as_code(dtype.elementType))
        res.append(")")

    elif isinstance(dtype, T.MapType):
        res.append("MapType(")
        res.append(f"\n\t{print_schema_as_code(dtype.keyType)},")
        for line in print_schema_as_code(dtype.valueType).split("\n"):
            res.append("\n\t")
            res.append(line)
        res.append(",")
        res.append(f"\n\t{dtype.valueContainsNull},")
        res.append("\n)")

    elif isinstance(dtype, T.DecimalType):
        res.append(f"DecimalType({dtype.precision}, {dtype.scale})")

    elif str(dtype).endswith("()"):
        # PySpark 3.3+
        res.append(str(dtype))
    else:
        res.append(f"{dtype}()")

    return "".join(res)
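
Example usage (a sketch; the schema below is illustrative):

from pyspark.sql import types as T

from quinn.schema_helpers import print_schema_as_code

schema = T.StructType([
    T.StructField("amount", T.DecimalType(10, 2), True),
    T.StructField("tags", T.ArrayType(T.StringType()), True),
])

# Prints Python source along the lines of "StructType(\n\tfields=[ ... ])";
# evaluating it with the pyspark.sql.types names in scope rebuilds the schema.
print(print_schema_as_code(schema))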

schema_from_csv(spark, file_path)

Return a StructType from a CSV file containing schema configuration.

Parameters:

Name       Type          Description                                                      Default
spark      SparkSession  The SparkSession object.                                         required
file_path  str           The path to the CSV file containing the schema configuration.   required

Returns:

Type                          Description
pyspark.sql.types.StructType  A StructType object representing the schema configuration.

Raises:

Type        Description
ValueError  If the CSV file does not contain the expected columns, in order: name, type, nullable, metadata.

Source code in quinn/schema_helpers.py
def schema_from_csv(spark: SparkSession, file_path: str) -> T.StructType:  # noqa: C901
    """Return a StructType from a CSV file containing schema configuration.

    :param spark: The SparkSession object
    :type spark: pyspark.sql.session.SparkSession

    :param file_path: The path to the CSV file containing the schema configuration
    :type file_path: str

    :raises ValueError: If the CSV file does not contain the expected columns, in order: name, type, nullable, metadata

    :return: A StructType object representing the schema configuration
    :rtype: pyspark.sql.types.StructType
    """

    def _validate_json(metadata: Optional[str]) -> dict:
        if metadata is None:
            return {}

        try:
            metadata_dict = json.loads(metadata)

        except json.JSONDecodeError as exc:
            msg = f"Invalid JSON: {metadata}"
            raise ValueError(msg) from exc

        return metadata_dict

    def _lookup_type(type_str: str) -> T.DataType:
        type_lookup = {
            "string": T.StringType(),
            "int": T.IntegerType(),
            "float": T.FloatType(),
            "double": T.DoubleType(),
            "boolean": T.BooleanType(),
            "bool": T.BooleanType(),
            "timestamp": T.TimestampType(),
            "date": T.DateType(),
            "binary": T.BinaryType(),
        }

        if type_str not in type_lookup:
            msg = f"Invalid type: {type_str}. Expecting one of: {type_lookup.keys()}"
            raise ValueError(msg)

        return type_lookup[type_str]

    def _convert_nullable(null_str: Optional[str]) -> bool:
        if null_str is None:
            return True

        parsed_val = null_str.lower()
        if parsed_val not in ["true", "false"]:
            msg = f"Invalid nullable value: {null_str}. Expecting True or False."
            raise ValueError(msg)

        return parsed_val == "true"

    schema_df = spark.read.csv(file_path, header=True)
    possible_columns = ["name", "type", "nullable", "metadata"]
    num_cols = len(schema_df.columns)
    expected_columns = possible_columns[0:num_cols]

    # ensure that the csv contains the expected columns, in order: name, type, nullable, metadata
    if schema_df.columns != expected_columns:
        msg = f"CSV must contain columns in this order: {expected_columns}"
        raise ValueError(msg)

    # create a StructType per field
    fields = []
    for row in schema_df.collect():
        field = T.StructField(
            name=row["name"],
            dataType=_lookup_type(row["type"]),
            nullable=_convert_nullable(row["nullable"]) if "nullable" in row else True,
            metadata=_validate_json(row["metadata"] if "metadata" in row else None),
        )
        fields.append(field)

    return T.StructType(fields=fields)
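
Example usage (a sketch; the file path and CSV contents are illustrative, and trailing columns such as nullable and metadata may be omitted from the file entirely):

from pyspark.sql import SparkSession

from quinn.schema_helpers import schema_from_csv

spark = SparkSession.builder.getOrCreate()

# Illustrative schema-configuration file. Columns must appear in the expected
# order; empty nullable cells default to True and empty metadata cells to {}.
with open("/tmp/schema_config.csv", "w") as f:
    f.write(
        "name,type,nullable,metadata\n"
        "person,string,false,\n"
        "address,string,,\n"
        "age,int,true,\n"
    )

schema = schema_from_csv(spark, "/tmp/schema_config.csv")
print(schema)
# StructType with person (non-nullable string), address (nullable string),
# and age (nullable int).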