Rebase #7223: improve contracts error message #7319

Merged: 11 commits, Apr 11, 2023
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230325-192830.yaml
@@ -0,0 +1,6 @@
kind: Features
body: Added prettier printing to ContractError class
time: 2023-03-25T19:28:30.171461-07:00
custom:
Author: kentkr
Issue: "7209"
65 changes: 59 additions & 6 deletions core/dbt/exceptions.py
@@ -1,6 +1,8 @@
import builtins
import json
import re
import io
import agate
from typing import Any, Dict, List, Mapping, Optional, Union

from dbt.dataclass_schema import ValidationError
@@ -2145,14 +2147,65 @@ def __init__(self, yaml_columns, sql_columns):
self.sql_columns = sql_columns
super().__init__(msg=self.get_message())

def get_message(self) -> str:
def get_mismatches(self) -> agate.Table:
# avoid a circular import
from dbt.clients.agate_helper import table_from_data_flat

column_names = ["column_name", "definition_type", "contract_type", "mismatch_reason"]
# list of mismatches
mismatches: List[Dict[str, str]] = []
# track sql cols so we don't need another for loop later
sql_col_set = set()
# for each sql column
for sql_col in self.sql_columns:
# add sql col to set
sql_col_set.add(sql_col["name"])
# for each yaml column
for i, yaml_col in enumerate(self.yaml_columns):
# if name matches
if sql_col["name"] == yaml_col["name"]:
# if type matches
if sql_col["data_type"] == yaml_col["data_type"]:
# it's a perfect match; don't include it in the mismatch table
break
else:
# same name, diff type
row = [
sql_col["name"],
sql_col["data_type"],
yaml_col["data_type"],
"data type mismatch",
]
mismatches += [dict(zip(column_names, row))]
break
# if last loop, then no name match
if i == len(self.yaml_columns) - 1:
row = [sql_col["name"], sql_col["data_type"], "", "missing in contract"]
mismatches += [dict(zip(column_names, row))]

# now add all yaml cols without a match
for yaml_col in self.yaml_columns:
if yaml_col["name"] not in sql_col_set:
row = [yaml_col["name"], "", yaml_col["data_type"], "missing in definition"]
mismatches += [dict(zip(column_names, row))]

mismatches_sorted = sorted(mismatches, key=lambda d: d["column_name"])
return table_from_data_flat(mismatches_sorted, column_names)

def get_message(self) -> str:
table: agate.Table = self.get_mismatches()
# Hack to get Agate table output as string
output = io.StringIO()
table.print_table(output=output, max_rows=None, max_column_width=50) # type: ignore
mismatches = output.getvalue()

msg = (
"Contracts are enabled for this model. "
"Please ensure the name, data_type, and number of columns in your `yml` file "
"match the columns in your SQL file.\n"
f"Schema File Columns: {self.yaml_columns}\n"
f"SQL File Columns: {self.sql_columns}"
"This model has an enforced contract that failed.\n"
"Please ensure the name, data_type, and number of columns in your contract "
"match the columns in your model's definition.\n\n"
f"{mismatches}"
)

return msg
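
For illustration only, and not part of the diff: a minimal, self-contained sketch of what the new rendering path does, building a mismatch table with the same four columns via plain agate and printing it to a string through an in-memory buffer, as get_message now does. The sample rows below are hypothetical.

import io
import agate

# Hypothetical mismatch rows, mirroring the column_names used by get_mismatches()
column_names = ["column_name", "definition_type", "contract_type", "mismatch_reason"]
rows = [
    ("color", "text", "integer", "data type mismatch"),
    ("error", "integer", "", "missing in contract"),
    ("id", "", "integer", "missing in definition"),
]
table = agate.Table(rows, column_names)

# Render the table to a string via the same StringIO trick used in get_message()
output = io.StringIO()
table.print_table(output=output, max_rows=None, max_column_width=50)
print(output.getvalue())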


@@ -38,12 +38,10 @@

{#-- create dictionaries with name and formatted data type and strings for exception #}
{%- set sql_columns = format_columns(sql_file_provided_columns) -%}
{%- set string_sql_columns = stringify_formatted_columns(sql_columns) -%}
{%- set yaml_columns = format_columns(schema_file_provided_columns) -%}
{%- set string_yaml_columns = stringify_formatted_columns(yaml_columns) -%}

{%- if sql_columns|length != yaml_columns|length -%}
{%- do exceptions.raise_contract_error(string_yaml_columns, string_sql_columns) -%}
{%- do exceptions.raise_contract_error(yaml_columns, sql_columns) -%}
{%- endif -%}

{%- for sql_col in sql_columns -%}
@@ -56,11 +54,11 @@
{%- endfor -%}
{%- if not yaml_col -%}
{#-- Column with name not found in yaml #}
{%- do exceptions.raise_contract_error(string_yaml_columns, string_sql_columns) -%}
{%- do exceptions.raise_contract_error(yaml_columns, sql_columns) -%}
{%- endif -%}
{%- if sql_col['formatted'] != yaml_col[0]['formatted'] -%}
{#-- Column data types don't match #}
{%- do exceptions.raise_contract_error(string_yaml_columns, string_sql_columns) -%}
{%- do exceptions.raise_contract_error(yaml_columns, sql_columns) -%}
{%- endif -%}
{%- endfor -%}

@@ -70,19 +68,13 @@
{% set formatted_columns = [] %}
{% for column in columns %}
{%- set formatted_column = adapter.dispatch('format_column', 'dbt')(column) -%}
{%- do formatted_columns.append({'name': column.name, 'formatted': formatted_column}) -%}
{%- do formatted_columns.append(formatted_column) -%}
{% endfor %}
{{ return(formatted_columns) }}
{% endmacro %}

{% macro stringify_formatted_columns(formatted_columns) %}
{% set column_strings = [] %}
{% for column in formatted_columns %}
{% do column_strings.append(column['formatted']) %}
{% endfor %}
{{ return(column_strings|join(', ')) }}
{% endmacro %}

{% macro default__format_column(column) -%}
{{ return(column.column.lower() ~ " " ~ column.dtype) }}
{% set data_type = column.dtype %}
{% set formatted = column.column.lower() ~ " " ~ data_type %}
{{ return({'name': column.name, 'data_type': data_type, 'formatted': formatted}) }}
{%- endmacro -%}
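
Also for reference, a hedged sketch (outside the diff) of the data shape the reworked macros now pass around: format_columns returns one dict per column, and raise_contract_error receives the two lists directly rather than pre-joined strings. The concrete values below are hypothetical.

# Hypothetical return value of format_columns() after this change
sql_columns = [
    {"name": "id", "data_type": "integer", "formatted": "id integer"},
    {"name": "color", "data_type": "text", "formatted": "color text"},
]
yaml_columns = [
    {"name": "id", "data_type": "integer", "formatted": "id integer"},
]

# The macro compares the 'formatted' values; any miss means both lists are handed
# to exceptions.raise_contract_error, which builds the ContractError shown above.
yaml_formatted = {c["formatted"] for c in yaml_columns}
mismatched = [c for c in sql_columns if c["formatted"] not in yaml_formatted]
print(mismatched)  # [{'name': 'color', ...}] -> a contract error would be raised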
32 changes: 9 additions & 23 deletions tests/adapter/dbt/tests/adapter/constraints/test_constraints.py
@@ -80,17 +80,8 @@ def test__constraints_wrong_column_names(self, project, string_type, int_type):

assert contract_actual_config.enforced is True

expected_compile_error = "Please ensure the name, data_type, and number of columns in your `yml` file match the columns in your SQL file."
expected_schema_file_columns = (
f"Schema File Columns: id {int_type}, color {string_type}, date_day {string_type}"
)
expected_sql_file_columns = (
f"SQL File Columns: color {string_type}, error {int_type}, date_day {string_type}"
)

assert expected_compile_error in log_output
assert expected_schema_file_columns in log_output
assert expected_sql_file_columns in log_output
expected = ["id", "error", "missing in definition", "missing in contract"]
assert all([(exp in log_output or exp.upper() in log_output) for exp in expected])

def test__constraints_wrong_column_data_types(
self, project, string_type, int_type, schema_int_type, data_types
@@ -128,18 +119,13 @@ def test__constraints_wrong_column_data_types(
contract_actual_config = my_model_config.contract

assert contract_actual_config.enforced is True

expected_compile_error = "Please ensure the name, data_type, and number of columns in your `yml` file match the columns in your SQL file."
expected_sql_file_columns = (
f"SQL File Columns: wrong_data_type_column_name {error_data_type}"
)
expected_schema_file_columns = (
f"Schema File Columns: wrong_data_type_column_name {wrong_schema_error_data_type}"
)

assert expected_compile_error in log_output
assert expected_schema_file_columns in log_output
assert expected_sql_file_columns in log_output
expected = [
"wrong_data_type_column_name",
error_data_type,
wrong_schema_error_data_type,
"data type mismatch",
]
assert all([(exp in log_output or exp.upper() in log_output) for exp in expected])

def test__constraints_correct_column_data_types(self, project, data_types):
for (sql_column_value, schema_data_type, _) in data_types:
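
One aside on the relaxed assertions above: checking both the literal token and its upper-cased form is presumably meant to tolerate adapters that report identifiers and type names in upper case. A trivial standalone sketch with a hypothetical log string:

log_output = "| ID |  | INTEGER | missing in definition |"
expected = ["id", "missing in definition"]
assert all(exp in log_output or exp.upper() in log_output for exp in expected)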