Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 56 additions & 30 deletions sqlglot/typing/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,65 @@
from sqlglot.typing import ExprMetadataType


def _annotate_by_similar_args(
self: TypeAnnotator, expression: E, *args: str, target_type: exp.DataType | exp.DType
) -> E:
def _common_array_element_type(types: list[exp.DataType]) -> exp.DataType | exp.DType:
"""
Infers the type of the expression according to the following rules:
- If all args are of the same type OR any arg is of target_type, the expr is inferred as such
- If any arg is of UNKNOWN type and none of target_type, the expr is inferred as UNKNOWN
Recursively narrows a list of CONCAT-arg DataTypes to their common type.

- Returns UNKNOWN for incompatible types: scalar mismatches (INT + DATE),
ARRAY mixed with non-ARRAY (e.g. CONCAT(ARRAY<INT>, INT)), and unmatched
nesting depths (which yield ARRAY<UNKNOWN> at the appropriate level).
UNKNOWN means "no common type", which the caller relies on.
- Return type is exp.DataType | exp.DType: bare DType for simple cases,
DataType for the recursive ARRAY case where nesting is required.
"""
expressions: list[exp.Expr] = []
for arg in args:
arg_expr = expression.args.get(arg)
expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
normalized = [
exp.DataType(this=exp.DType.TEXT) if t.this in exp.DataType.TEXT_TYPES else t for t in types
]
if len({t.sql() for t in normalized}) == 1:
return normalized[0]
if all(t.this == exp.DType.ARRAY for t in normalized):
elem_types = [
t.expressions[0] if t.expressions else exp.DataType(this=exp.DType.UNKNOWN)
for t in normalized
]
common_elem = _common_array_element_type(elem_types)
elem_dt = (
common_elem if isinstance(common_elem, exp.DataType) else exp.DataType(this=common_elem)
)
return exp.DataType(this=exp.DType.ARRAY, expressions=[elem_dt], nested=True)
if any(t.this == exp.DType.TEXT for t in normalized):
return exp.DType.TEXT
return exp.DType.UNKNOWN

last_datatype = None

has_unknown = False
for expr in expressions:
if expr.is_type(exp.DType.UNKNOWN):
has_unknown = True
elif expr.is_type(target_type):
has_unknown = False
last_datatype = target_type
break
else:
last_datatype = expr.type
def _annotate_by_similar_args(self: TypeAnnotator, expression: E, *arg_keys: str) -> E:
"""
Type inference for CONCAT-family expressions (CONCAT, LPAD, RPAD).

- TEXT-before-UNKNOWN is load-bearing: a known text arg forces a text
result, since the query either coerces the unknown to string or fails
entirely — no valid execution produces a non-text result.
- TEXT_TYPES on input narrows to DType.TEXT on output: CONCAT/LPAD
accept any TEXT_TYPES member (VARCHAR/CHAR/NCHAR/NVARCHAR/NAME) as
input, but Spark always emits DType.TEXT.
"""
arg_exprs: list[exp.Expression] = []
for key in arg_keys:
arg_exprs.extend(e for e in ensure_list(expression.args.get(key)) if e)

self._set_type(expression, exp.DType.UNKNOWN if has_unknown else last_datatype)
result: exp.DataType | exp.DType
if any(e.is_type(*exp.DataType.TEXT_TYPES) for e in arg_exprs):
result = exp.DType.TEXT
elif any(e.is_type(exp.DType.UNKNOWN) for e in arg_exprs):
result = exp.DType.UNKNOWN
elif all(e.is_type(exp.DType.BINARY) for e in arg_exprs):
result = exp.DType.BINARY
elif any(e.is_type(exp.DType.ARRAY) for e in arg_exprs):
result = _common_array_element_type([e.type for e in arg_exprs])
else:
result = exp.DType.TEXT

self._set_type(expression, result)
return expression


Expand Down Expand Up @@ -72,15 +104,9 @@ def _annotate_by_similar_args(
)
},
exp.AtTimeZone: {"returns": exp.DType.TIMESTAMP},
exp.Concat: {
"annotator": lambda self, e: _annotate_by_similar_args(
self, e, "expressions", target_type=exp.DType.TEXT
)
},
exp.Concat: {"annotator": lambda self, e: _annotate_by_similar_args(self, e, "expressions")},
exp.NextDay: {"returns": exp.DType.DATE},
exp.Pad: {
"annotator": lambda self, e: _annotate_by_similar_args(
self, e, "this", "fill_pattern", target_type=exp.DType.TEXT
)
"annotator": lambda self, e: _annotate_by_similar_args(self, e, "this", "fill_pattern")
},
}
28 changes: 28 additions & 0 deletions tests/fixtures/optimizer/annotate_functions.sql
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,34 @@ UNKNOWN;
CONCAT(unknown, unknown);
UNKNOWN;

# dialect: spark2, spark, databricks
CONCAT('x', tbl.date_col);
STRING;

# dialect: spark2, spark, databricks
CONCAT(tbl.date_col, tbl.date_col);
STRING;

# dialect: spark2, spark, databricks
CONCAT('x', tbl.bin_col);
STRING;

# dialect: spark2, spark, databricks
CONCAT(array('a', 'b'), array(1, 2));
ARRAY<STRING>;

# dialect: spark2, spark, databricks
CONCAT(array(array('a')), array(array(1)));
ARRAY<ARRAY<STRING>>;

# dialect: spark2, spark, databricks
CONCAT(tbl.date_col, tbl.int_col);
STRING;

# dialect: spark2, spark, databricks
LPAD('x', 10, tbl.date_col);
STRING;

# dialect: spark2, spark, databricks
LPAD(tbl.bin_col, 1, tbl.bin_col);
BINARY;
Expand Down
Loading