Coverage for dj/sql/functions.py: 100%
393 statements
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-17 20:05 -0700
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-17 20:05 -0700
1# pylint: disable=too-many-lines
2# mypy: ignore-errors
4"""
5SQL functions for type inference.
7This file holds all the functions that we want to support in the SQL used to define
8nodes. The functions are used to infer types.
10Spark function reference
11https://github.com/apache/spark/tree/74cddcfda3ac4779de80696cdae2ba64d53fc635/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions
13Java strictmath reference
14https://docs.oracle.com/javase/8/docs/api/java/lang/StrictMath.html
16Databricks reference:
17https://docs.databricks.com/sql/language-manual/sql-ref-functions-builtin-alpha.html
18"""
20import inspect
21import re
22from itertools import zip_longest
24# pylint: disable=unused-argument, missing-function-docstring, arguments-differ, too-many-return-statements
25from typing import (
26 TYPE_CHECKING,
27 Callable,
28 ClassVar,
29 Dict,
30 List,
31 Optional,
32 Tuple,
33 Type,
34 Union,
35 get_origin,
36)
38import dj.sql.parsing.types as ct
39from dj.errors import (
40 DJError,
41 DJInvalidInputException,
42 DJNotImplementedException,
43 ErrorCode,
44)
45from dj.sql.parsing.backends.exceptions import DJParseException
47if TYPE_CHECKING:
48 from dj.sql.parsing.ast import Expression
def compare_registers(types, register) -> bool:
    """
    Check whether the argument types of a call satisfy a registered signature.

    Both ``types`` and ``register`` are sequences of ``(position, type)`` pairs.
    A position of ``-1`` in the last entry of ``register`` marks a ``*args``
    parameter, whose type then applies to every surplus argument. A parameter
    that receives no argument is compared as ``NoneType``, so it only matches
    when ``NoneType`` is acceptable for that parameter.

    Returns ``True`` when every pair is compatible (by ``issubclass``), and
    ``False`` otherwise, including when there are more arguments than
    parameters and the signature has no ``*args`` entry.
    """
    # Guard against an empty signature (a function declared with no
    # parameters): indexing ``register[-1]`` on it would raise IndexError.
    has_varargs = bool(register) and register[-1][0] == -1
    for (type_a, register_a), (type_b, register_b) in zip_longest(
        types,
        register,
        fillvalue=(-1, None),
    ):
        if type_b == -1 and register_b is None:
            # The call supplied more arguments than the signature declares.
            if not has_varargs:
                return False
            register_b = register[-1][1]
        if type_a == -1:
            # No argument was supplied for this parameter: compare NoneType.
            register_a = type(register_a)
        if not issubclass(register_a, register_b):  # type: ignore
            return False
    return True
class DispatchMeta(type):
    """
    Metaclass that turns registered function names into dispatching callables.
    """

    def __getattribute__(cls, func_name):  # pylint: disable=redefined-outer-name
        """
        Return a dispatching wrapper for names registered on ``cls``;
        fall back to normal attribute lookup for everything else.
        """
        # ``registry`` is read through ``type.__getattribute__`` so that the
        # lookup does not re-enter this hook.
        registered = type.__getattribute__(cls, "registry").get(cls, {})
        if func_name not in registered:
            return type.__getattribute__(cls, func_name)

        def dynamic_dispatch(*args: "Expression"):
            implementation = cls.dispatch(func_name, *args)
            return implementation(*args)

        return dynamic_dispatch
class Dispatch(metaclass=DispatchMeta):
    """
    Function registry

    Implementations are stored per class and per function name, keyed by a
    tuple of ``(position, type)`` pairs describing one accepted signature.
    """

    # Shared by all subclasses; the first key is the registering class.
    registry: ClassVar[Dict[str, Dict[Tuple[Tuple[int, Type]], Callable]]] = {}

    @classmethod
    def register(cls, func):  # pylint: disable=redefined-outer-name
        """
        Register ``func`` on ``cls`` under every signature its annotations allow.

        A ``Union`` annotation is expanded into one signature per member type,
        and a ``*args`` parameter is stored with position ``-1``. Returns
        ``None``, so a name decorated with this method is rebound to ``None``.

        Raises ``ValueError`` for ``**kwargs`` parameters and for parameters
        that have no type annotation.
        """
        func_name = func.__name__
        params = inspect.signature(func).parameters
        overloads = cls.registry.setdefault(cls, {}).setdefault(func_name, {})
        # Build each combination exactly once, in the same order as before
        # (the newest parameter's options form the outer loop), so the
        # registry keys and their insertion order are unchanged.
        signatures = [[]]
        for position, param in enumerate(params.values()):
            if param.kind is inspect.Parameter.VAR_KEYWORD:
                raise ValueError(
                    "kwargs are not supported in dispatch.",
                )  # pragma: no cover
            if param.kind is inspect.Parameter.VAR_POSITIONAL:
                position = -1
            annotation = param.annotation
            if annotation == inspect.Parameter.empty:
                raise ValueError(  # pragma: no cover
                    "All arguments must have a type annotation.",
                )
            if get_origin(annotation) == Union:
                options = annotation.__args__
            else:
                options = [annotation]
            signatures = [
                signature + [(position, option)]
                for option in options
                for signature in signatures
            ]
        for signature in signatures:
            overloads[tuple(signature)] = func  # type: ignore

    @classmethod
    def dispatch(  # pylint: disable=redefined-outer-name
        cls, func_name, *args: "Expression"
    ):
        """
        Return the implementation of ``func_name`` that accepts ``args``.

        Each argument is represented by ``type(arg.type)`` when it has a
        ``type`` attribute and by ``type(arg)`` otherwise. An exact signature
        match is tried first, then the registered signatures are scanned in
        registration order for a compatible one.

        Raises ``ValueError`` when nothing is registered under ``func_name``
        and ``TypeError`` when no signature accepts the argument types.
        """
        # ``.get(cls, {})`` keeps a class without any registration on the
        # ValueError path instead of failing with a KeyError.
        type_registry = cls.registry.get(cls, {}).get(func_name)  # type: ignore
        if not type_registry:
            raise ValueError(
                f"No function registered on {cls.__name__}`{func_name}`.",
            )  # pragma: no cover

        types = tuple(
            (i, type(arg.type) if hasattr(arg, "type") else type(arg))
            for i, arg in enumerate(args)
        )

        if types in type_registry:  # type: ignore
            return type_registry[types]  # type: ignore

        for register, func in type_registry.items():  # type: ignore
            if compare_registers(types, register):
                return func

        raise TypeError(
            f"`{cls.__name__}.{func_name}` got an invalid "
            "combination of types: "
            f'{", ".join(str(t[1].__name__) for t in types)}',
        )
class Function(Dispatch):  # pylint: disable=too-few-public-methods
    """
    A DJ function; the base class of the SQL functions defined in this module.
    """

    # Aggregating subclasses override this with True.
    is_aggregation: ClassVar[bool] = False

    @staticmethod
    def infer_type(*args) -> ct.ColumnType:
        """Placeholder used when a subclass supplies no implementation."""
        raise NotImplementedError
class TableFunction(Dispatch):  # pylint: disable=too-few-public-methods
    """
    A DJ table-valued function; its inferred type is a list of column types.
    """

    @staticmethod
    def infer_type(*args) -> List[ct.ColumnType]:
        """Placeholder used when a subclass supplies no implementation."""
        raise NotImplementedError
class Avg(Function):  # pylint: disable=abstract-method
    """
    Aggregation computing the average of the input column or expression.
    """

    is_aggregation = True


@Avg.register
def infer_type(
    arg: ct.DecimalType,
) -> ct.DecimalType:  # noqa: F811  # pylint: disable=function-redefined
    """Decimal input: precision and scale are both widened by 4."""
    decimal = arg.type
    widened_precision = decimal.precision + 4
    widened_scale = decimal.scale + 4
    return ct.DecimalType(widened_precision, widened_scale)


@Avg.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    arg: ct.IntervalTypeBase,
) -> ct.IntervalTypeBase:
    """Interval input: a fresh instance of the same interval type."""
    interval_cls = type(arg.type)
    return interval_cls()


@Avg.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    arg: ct.NumberType,
) -> ct.DoubleType:
    """Any other numeric input averages to a double."""
    return ct.DoubleType()
class Min(Function):  # pylint: disable=abstract-method
    """
    Aggregation computing the minimum value of the input column or expression.
    """

    is_aggregation = True


@Min.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    arg: ct.NumberType,
) -> ct.NumberType:
    """The minimum of a numeric input keeps the input's own type."""
    input_type = arg.type
    return input_type
class Max(Function):  # pylint: disable=abstract-method
    """
    Aggregation computing the maximum value of the input column or expression.
    """

    is_aggregation = True


@Max.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    arg: ct.NumberType,
) -> ct.NumberType:
    """The maximum of a numeric input keeps the input's own type."""
    input_type = arg.type
    return input_type


@Max.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    arg: ct.StringType,
) -> ct.StringType:
    """The maximum of a string input keeps the input's own type."""
    input_type = arg.type
    return input_type
class Sum(Function):  # pylint: disable=abstract-method
    """
    Aggregation computing the sum of the input column or expression.
    """

    is_aggregation = True


@Sum.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    arg: ct.IntegerBase,
) -> ct.BigIntType:
    """Integer inputs sum to a bigint."""
    return ct.BigIntType()


@Sum.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    arg: ct.DecimalType,
) -> ct.DecimalType:
    """Decimal inputs keep their scale; precision grows by at most 10."""
    decimal = arg.type
    precision = decimal.precision
    # The growth is limited so the resulting precision never exceeds 31.
    extra_digits = min(10, 31 - precision)
    return ct.DecimalType(precision + extra_digits, decimal.scale)


@Sum.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    arg: Union[ct.NumberType, ct.IntervalTypeBase],
) -> ct.DoubleType:
    """Remaining numeric and interval inputs sum to a double."""
    return ct.DoubleType()
class Ceil(Function):  # pylint: disable=abstract-method
    """
    Computes the smallest integer greater than or equal to the input value.
    """


@Ceil.register
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    args: ct.NumberType,
    _target_scale: ct.IntegerType,
) -> ct.DecimalType:
    """
    Two-argument form: the result is a decimal whose precision and scale
    depend on the operand type and the requested target scale.
    """
    target_scale = _target_scale.value
    operand = args.type
    # Lower bound on the precision, as derived from the target scale.
    scale_floor = -target_scale + 1
    if isinstance(operand, ct.DecimalType):
        return ct.DecimalType(
            max(operand.precision - operand.scale + 1, scale_floor),
            min(operand.scale, max(0, target_scale)),
        )

    # (type, base precision) for operands whose result has scale 0.
    integral_operands = (
        (ct.TinyIntType(), 3),
        (ct.SmallIntType(), 5),
        (ct.IntegerType(), 10),
        (ct.BigIntType(), 20),
    )
    for candidate, digits in integral_operands:
        if operand == candidate:
            return ct.DecimalType(max(digits, scale_floor), 0)

    # (type, base precision, maximum scale) for floating point operands.
    fractional_operands = (
        (ct.FloatType(), 14, 7),
        (ct.DoubleType(), 30, 15),
    )
    for candidate, digits, max_scale in fractional_operands:
        if operand == candidate:
            return ct.DecimalType(
                max(digits, scale_floor),
                min(max_scale, max(0, target_scale)),
            )

    raise DJParseException(
        f"Unhandled numeric type in Ceil `{args.type}`",
    )  # pragma: no cover


@Ceil.register
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    args: ct.DecimalType,
) -> ct.DecimalType:
    """Decimal operand: integral digits plus one, with scale 0."""
    decimal = args.type
    integral_digits = decimal.precision - decimal.scale + 1
    return ct.DecimalType(integral_digits, 0)


@Ceil.register
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    args: ct.NumberType,
) -> ct.BigIntType:
    """Any other numeric operand yields a bigint."""
    return ct.BigIntType()
class Count(Function):  # pylint: disable=abstract-method
    """
    Aggregation counting the non-null values of the input column or expression.
    """

    is_aggregation = True


@Count.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    *args: ct.ColumnType,
) -> ct.BigIntType:
    """A count is a bigint, whatever the argument types are."""
    return ct.BigIntType()
class Coalesce(Function):  # pylint: disable=abstract-method
    """
    Returns the first non-null value among its arguments.
    """

    is_aggregation = False


@Coalesce.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    *args: ct.ColumnType,
) -> ct.ColumnType:
    """
    Infer the type of the first argument whose type is not ``NullType``.

    Returns ``NullType`` when every argument is null-typed, and raises
    ``DJInvalidInputException`` when called without arguments.
    """
    if not args:  # pragma: no cover
        raise DJInvalidInputException(
            message="Wrong number of arguments to function",
            errors=[
                DJError(
                    code=ErrorCode.INVALID_ARGUMENTS_TO_FUNCTION,
                    message="You need to pass at least one argument to `COALESCE`.",
                ),
            ],
        )
    for arg in args:
        if arg.type != ct.NullType():
            return arg.type
    return ct.NullType()
class CurrentDate(Function):  # pylint: disable=abstract-method
    """
    Returns the current date.
    """


@CurrentDate.register  # type: ignore
def infer_type() -> ct.DateType:  # noqa: F811  # pylint: disable=function-redefined
    """Takes no arguments and is typed as a date."""
    date_type = ct.DateType()
    return date_type
class CurrentDatetime(Function):  # pylint: disable=abstract-method
    """
    Returns the current date and time.
    """


@CurrentDatetime.register  # type: ignore
def infer_type() -> ct.TimestampType:  # noqa: F811  # pylint: disable=function-redefined
    """Takes no arguments and is typed as a timestamp."""
    timestamp_type = ct.TimestampType()
    return timestamp_type
class CurrentTime(Function):  # pylint: disable=abstract-method
    """
    Returns the current time.
    """


@CurrentTime.register  # type: ignore
def infer_type() -> ct.TimeType:  # noqa: F811  # pylint: disable=function-redefined
    """Takes no arguments and is typed as a time."""
    time_type = ct.TimeType()
    return time_type
class CurrentTimestamp(Function):  # pylint: disable=abstract-method
    """
    Returns the current timestamp.
    """


@CurrentTimestamp.register  # type: ignore
def infer_type() -> ct.TimestampType:  # noqa: F811  # pylint: disable=function-redefined
    """Takes no arguments and is typed as a timestamp."""
    timestamp_type = ct.TimestampType()
    return timestamp_type
class Now(Function):  # pylint: disable=abstract-method
    """
    Returns the current timestamp.
    """


@Now.register  # type: ignore
def infer_type() -> ct.TimestamptzType:  # noqa: F811  # pylint: disable=function-redefined
    """Takes no arguments and is typed as a timestamptz."""
    timestamptz_type = ct.TimestamptzType()
    return timestamptz_type
class DateAdd(Function):  # pylint: disable=abstract-method
    """
    Adds a specified number of days to a date.
    """


@DateAdd.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    start_date: ct.DateType,
    days: ct.IntegerBase,
) -> ct.DateType:
    """A date start with an integer day count yields a date."""
    result_type = ct.DateType()
    return result_type


@DateAdd.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    start_date: ct.StringType,
    days: ct.IntegerBase,
) -> ct.DateType:
    """A string start with an integer day count also yields a date."""
    result_type = ct.DateType()
    return result_type
class DateSub(Function):  # pylint: disable=abstract-method
    """
    Subtracts a specified number of days from a date.
    """


@DateSub.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    start_date: ct.DateType,
    days: ct.IntegerBase,
) -> ct.DateType:
    """A date start with an integer day count yields a date."""
    result_type = ct.DateType()
    return result_type


@DateSub.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    start_date: ct.StringType,
    days: ct.IntegerBase,
) -> ct.DateType:
    """A string start with an integer day count also yields a date."""
    result_type = ct.DateType()
    return result_type
class If(Function):  # pylint: disable=abstract-method
    """
    If statement

    if(condition, result, else_result): returns result when the condition
    evaluates to true, and else_result otherwise.
    """


@If.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    cond: ct.BooleanType,
    then: ct.ColumnType,
    else_: ct.ColumnType,
) -> ct.ColumnType:
    """
    Both branches must share one type, which becomes the result type;
    differing branch types raise ``DJInvalidInputException``.
    """
    branches_differ = then.type != else_.type
    if branches_differ:
        raise DJInvalidInputException(
            message="The then result and else result must match in type! "
            f"Got {then.type} and {else_.type}",
        )

    return then.type
class DateDiff(Function):  # pylint: disable=abstract-method
    """
    Computes the difference in days between two dates.
    """


@DateDiff.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    start_date: ct.DateType,
    end_date: ct.DateType,
) -> ct.IntegerType:
    """Two date arguments yield an integer."""
    result_type = ct.IntegerType()
    return result_type


@DateDiff.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    start_date: ct.StringType,
    end_date: ct.StringType,
) -> ct.IntegerType:
    """Two string arguments also yield an integer."""
    result_type = ct.IntegerType()
    return result_type
class Extract(Function):
    """
    Returns a specified component of a timestamp, such as year, month or day.
    """

    @staticmethod
    def infer_type(  # type: ignore
        field: "Expression",
        source: "Expression",
    ) -> Union[ct.DecimalType, ct.IntegerType]:
        """The ``SECOND`` field is typed decimal(8, 6); other fields integer."""
        wants_seconds = str(field.name) == "SECOND"  # type: ignore
        return ct.DecimalType(8, 6) if wants_seconds else ct.IntegerType()
class ToDate(Function):  # pragma: no cover  # pylint: disable=abstract-method
    """
    Converts a date string to a date value.
    """


@ToDate.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    expr: ct.StringType,
    fmt: Optional[ct.StringType] = None,
) -> ct.DateType:
    """A string, with an optional string format, yields a date."""
    result_type = ct.DateType()
    return result_type
class Day(Function):  # pylint: disable=abstract-method
    """
    Returns the day of the month for a specified date.
    """


@Day.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    arg: Union[ct.StringType, ct.DateType, ct.TimestampType],
) -> ct.IntegerType:  # type: ignore
    """String, date and timestamp arguments all yield an integer."""
    result_type = ct.IntegerType()
    return result_type
class Exp(Function):  # pylint: disable=abstract-method
    """
    Returns e to the power of expr.
    """


@Exp.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    args: ct.ColumnType,
) -> ct.DoubleType:
    """The result is a double for any argument type."""
    result_type = ct.DoubleType()
    return result_type
class Floor(Function):  # pylint: disable=abstract-method
    """
    Returns the largest integer less than or equal to a specified number.
    """


@Floor.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    args: ct.DecimalType,
) -> ct.DecimalType:
    """Decimal operand: integral digits plus one, with scale 0."""
    decimal = args.type
    integral_digits = decimal.precision - decimal.scale + 1
    return ct.DecimalType(integral_digits, 0)


@Floor.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    args: ct.NumberType,
) -> ct.BigIntType:
    """Any other numeric operand yields a bigint."""
    return ct.BigIntType()


@Floor.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    args: ct.NumberType,
    _target_scale: ct.IntegerType,
) -> ct.DecimalType:
    """
    Two-argument form: the result is a decimal whose precision and scale
    depend on the operand type and the requested target scale.
    """
    target_scale = _target_scale.value
    operand = args.type
    # Lower bound on the precision, as derived from the target scale.
    scale_floor = -target_scale + 1
    if isinstance(operand, ct.DecimalType):  # pylint: disable=R1705
        return ct.DecimalType(
            max(operand.precision - operand.scale + 1, scale_floor),
            min(operand.scale, max(0, target_scale)),
        )

    # (type, base precision) for operands whose result has scale 0.
    integral_operands = (
        (ct.TinyIntType(), 3),
        (ct.SmallIntType(), 5),
        (ct.IntegerType(), 10),
        (ct.BigIntType(), 20),
    )
    for candidate, digits in integral_operands:
        if operand == candidate:
            return ct.DecimalType(max(digits, scale_floor), 0)

    # (type, base precision, maximum scale) for floating point operands.
    fractional_operands = (
        (ct.FloatType(), 14, 7),
        (ct.DoubleType(), 30, 15),
    )
    for candidate, digits, max_scale in fractional_operands:
        if operand == candidate:
            return ct.DecimalType(
                max(digits, scale_floor),
                min(max_scale, max(0, target_scale)),
            )

    raise DJParseException(
        f"Unhandled numeric type in Floor `{args.type}`",
    )  # pragma: no cover
class IfNull(Function):
    """
    Returns the second expression if the first is null, else returns the first expression.
    """

    @staticmethod
    def infer_type(*args: "Expression") -> ct.ColumnType:  # type: ignore
        """
        Use the second argument's type, unless that type is ``NullType``,
        in which case the first argument's type is used.
        """
        first, second = args[0], args[1]
        if second.type == ct.NullType():
            return first.type  # type: ignore
        return second.type  # type: ignore
class Length(Function):  # pylint: disable=abstract-method
    """
    Returns the length of a string.
    """


@Length.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    arg: ct.StringType,
) -> ct.IntegerType:
    """A string argument yields an integer."""
    result_type = ct.IntegerType()
    return result_type
class Levenshtein(Function):  # pylint: disable=abstract-method
    """
    Returns the Levenshtein distance between two strings.
    """


@Levenshtein.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    string1: ct.StringType,
    string2: ct.StringType,
) -> ct.IntegerType:
    """Two string arguments yield an integer."""
    result_type = ct.IntegerType()
    return result_type
class Ln(Function):  # pylint: disable=abstract-method
    """
    Returns the natural logarithm of a number.
    """


@Ln.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    args: ct.ColumnType,
) -> ct.DoubleType:
    """The result is a double for any argument type."""
    result_type = ct.DoubleType()
    return result_type
class Log(Function):  # pylint: disable=abstract-method
    """
    Returns the logarithm of a number with the specified base.
    """


@Log.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    base: ct.ColumnType,
    expr: ct.ColumnType,
) -> ct.DoubleType:
    """The result is a double for any pair of argument types."""
    result_type = ct.DoubleType()
    return result_type
class Log2(Function):  # pylint: disable=abstract-method
    """
    Returns the base-2 logarithm of a number.
    """


@Log2.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    args: ct.ColumnType,
) -> ct.DoubleType:
    """The result is a double for any argument type."""
    result_type = ct.DoubleType()
    return result_type
class Log10(Function):  # pylint: disable=abstract-method
    """
    Returns the base-10 logarithm of a number.
    """


@Log10.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    args: ct.ColumnType,
) -> ct.DoubleType:
    """The result is a double for any argument type."""
    result_type = ct.DoubleType()
    return result_type
class Lower(Function):
    """
    Converts a string to lowercase.
    """

    @staticmethod
    def infer_type(arg: "Expression") -> ct.StringType:  # type: ignore
        """The result is always a string."""
        result_type = ct.StringType()
        return result_type
class Month(Function):
    """
    Extracts the month of a date or timestamp.
    """

    @staticmethod
    def infer_type(arg: "Expression") -> ct.TinyIntType:  # type: ignore
        """The result is always a tinyint."""
        result_type = ct.TinyIntType()
        return result_type
class Pow(Function):  # pylint: disable=abstract-method
    """
    Raises a base expression to the power of an exponent expression.
    """


@Pow.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    base: ct.ColumnType,
    power: ct.ColumnType,
) -> ct.DoubleType:
    """The result is a double for any pair of argument types."""
    result_type = ct.DoubleType()
    return result_type
class PercentRank(Function):
    """
    Window function: returns the relative rank (i.e. percentile) of rows within a window partition
    """

    is_aggregation = True

    @staticmethod
    def infer_type() -> ct.DoubleType:
        """Takes no arguments and is typed as a double."""
        result_type = ct.DoubleType()
        return result_type
class Quantile(Function):  # pragma: no cover
    """
    Computes the quantile of a numerical column or expression.
    """

    is_aggregation = True

    @staticmethod
    def infer_type(  # type: ignore
        arg1: "Expression",
        arg2: "Expression",
    ) -> ct.DoubleType:
        """The result is always a double."""
        result_type = ct.DoubleType()
        return result_type
class ApproxQuantile(Function):  # pragma: no cover
    """
    Computes the approximate quantile of a numerical column or expression.
    """

    is_aggregation = True

    @staticmethod
    def infer_type(  # type: ignore
        arg1: "Expression",
        arg2: "Expression",
    ) -> ct.DoubleType:
        """The result is always a double."""
        result_type = ct.DoubleType()
        return result_type
class RegexpLike(Function):  # pragma: no cover
    """
    Matches a string column or expression against a regular expression pattern.
    """

    @staticmethod
    def infer_type(  # type: ignore
        arg1: "Expression",
        arg2: "Expression",
    ) -> ct.BooleanType:
        """The result is always a boolean."""
        result_type = ct.BooleanType()
        return result_type
class Round(Function):  # pylint: disable=abstract-method
    """
    Rounds a numeric column or expression to the specified number of decimal places.
    """


@Round.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    child: ct.DecimalType,
    scale: ct.IntegerBase,
) -> ct.NumberType:
    """
    Decimal input: derive the result precision and scale from the input's
    decimal type and the requested ``scale`` value.

    A negative scale yields a decimal with scale 0; otherwise the result
    scale is the smaller of the input's scale and the requested scale.
    """
    child_type = child.type
    integral_least_num_digits = child_type.precision - child_type.scale + 1
    if scale.value < 0:
        # The requested scale is read from the argument itself (``scale.value``),
        # as everywhere else in this function, not from ``scale.type``.
        new_precision = max(
            integral_least_num_digits,
            -scale.value + 1,
        )  # pragma: no cover
        return ct.DecimalType(new_precision, 0)  # pragma: no cover
    new_scale = min(child_type.scale, scale.value)
    return ct.DecimalType(integral_least_num_digits + new_scale, new_scale)


@Round.register
def infer_type(  # noqa: F811  # pylint: disable=function-redefined  # type: ignore
    child: ct.NumberType,
    scale: ct.IntegerBase,
) -> ct.NumberType:
    """Any other numeric input keeps its own type."""
    return child.type
class SafeDivide(Function):  # pragma: no cover
    """
    Divides two numeric columns or expressions and returns NULL if the denominator is 0.
    """

    @staticmethod
    def infer_type(arg1: "Expression", arg2: "Expression") -> ct.DoubleType:  # type: ignore
        """The result is always a double."""
        result_type = ct.DoubleType()
        return result_type
class Substring(Function):
    """
    Extracts a substring from a string column or expression.
    """

    @staticmethod
    def infer_type(  # type: ignore
        arg1: "Expression",
        arg2: "Expression",
        arg3: "Expression",
    ) -> ct.StringType:
        """The result is always a string."""
        result_type = ct.StringType()
        return result_type
class StrPosition(Function):  # pylint: disable=abstract-method
    """
    Returns the position of the first occurrence of a substring in a string column or expression.
    """


@StrPosition.register
def infer_type(  # noqa: F811  # pylint: disable=function-redefined # pragma: no cover
    arg1: ct.StringType,
    arg2: ct.StringType,
) -> ct.IntegerType:
    """Two string arguments yield an integer."""
    return ct.IntegerType()  # pragma: no cover
class StrToDate(Function):  # pragma: no cover
    """
    Converts a string in a specified format to a date.
    """

    @staticmethod
    def infer_type(arg1: "Expression", arg2: "Expression") -> ct.DateType:
        """The result is always a date."""
        result_type = ct.DateType()
        return result_type
class StrToTime(Function):  # pragma: no cover
    """
    Converts a string in a specified format to a timestamp.
    """

    @staticmethod
    def infer_type(arg1: "Expression", arg2: "Expression") -> ct.TimestampType:
        """The result is always a timestamp."""
        result_type = ct.TimestampType()
        return result_type
class Sqrt(Function):
    """
    Computes the square root of a numeric column or expression.
    """

    @staticmethod
    def infer_type(arg: "Expression") -> ct.DoubleType:
        """The result is always a double."""
        result_type = ct.DoubleType()
        return result_type
class Stddev(Function):
    """
    Computes the sample standard deviation of a numerical column or expression.
    """

    is_aggregation = True

    @staticmethod
    def infer_type(arg: "Expression") -> ct.DoubleType:
        """The result is always a double."""
        result_type = ct.DoubleType()
        return result_type
class StddevPop(Function):  # pragma: no cover
    """
    Computes the population standard deviation of the input column or expression.
    """

    is_aggregation = True

    @staticmethod
    def infer_type(arg: "Expression") -> ct.DoubleType:
        """The result is always a double."""
        result_type = ct.DoubleType()
        return result_type
class StddevSamp(Function):  # pragma: no cover
    """
    Computes the sample standard deviation of the input column or expression.
    """

    is_aggregation = True

    @staticmethod
    def infer_type(arg: "Expression") -> ct.DoubleType:
        """The result is always a double."""
        result_type = ct.DoubleType()
        return result_type
class TimeToStr(Function):  # pragma: no cover
    """
    Converts a time value to a string using the specified format.
    """

    @staticmethod
    def infer_type(arg1: "Expression", arg2: "Expression") -> ct.StringType:
        """The result is always a string."""
        result_type = ct.StringType()
        return result_type
class TimeToTimeStr(Function):  # pragma: no cover
    """
    Converts a time value to a string using the specified format.
    """

    @staticmethod
    def infer_type(arg1: "Expression", arg2: "Expression") -> ct.StringType:
        """The result is always a string."""
        result_type = ct.StringType()
        return result_type
class TimeStrToDate(Function):  # pragma: no cover
    """
    Converts a string value to a date.
    """

    @staticmethod
    def infer_type(arg: "Expression") -> ct.DateType:
        """The result is always a date."""
        result_type = ct.DateType()
        return result_type
class TimeStrToTime(Function):  # pragma: no cover
    """
    Converts a string value to a time.
    """

    @staticmethod
    def infer_type(arg: "Expression") -> ct.TimestampType:
        """The result is always typed as a timestamp."""
        result_type = ct.TimestampType()
        return result_type
class Trim(Function):  # pragma: no cover
    """
    Removes leading and trailing whitespace from a string value.
    """

    @staticmethod
    def infer_type(arg: "Expression") -> ct.StringType:
        """The result is always a string."""
        result_type = ct.StringType()
        return result_type
class TsOrDsToDateStr(Function):  # pragma: no cover
    """
    Converts a timestamp or date value to a string using the specified format.
    """

    @staticmethod
    def infer_type(arg1: "Expression", arg2: "Expression") -> ct.StringType:
        """The result is always a string."""
        result_type = ct.StringType()
        return result_type
class TsOrDsToDate(Function):  # pragma: no cover
    """
    Converts a timestamp or date value to a date.
    """

    @staticmethod
    def infer_type(arg: "Expression") -> ct.DateType:
        """The result is always a date."""
        result_type = ct.DateType()
        return result_type
class TsOrDiToDi(Function):  # pragma: no cover
    """
    Converts a timestamp or date value to a date.

    NOTE(review): the inferred type is an integer, so the result is presumably
    an integer-encoded date rather than a date value -- confirm.
    """

    @staticmethod
    def infer_type(arg: "Expression") -> ct.IntegerType:
        """The result is always an integer."""
        result_type = ct.IntegerType()
        return result_type
class UnixToStr(Function):  # pragma: no cover
    """
    Converts a Unix timestamp to a string using the specified format.
    """

    @staticmethod
    def infer_type(arg1: "Expression", arg2: "Expression") -> ct.StringType:
        """The result is always a string."""
        result_type = ct.StringType()
        return result_type
class UnixToTime(Function):  # pragma: no cover
    """
    Converts a Unix timestamp to a time.
    """

    @staticmethod
    def infer_type(arg: "Expression") -> ct.TimestampType:
        """The result is always typed as a timestamp."""
        result_type = ct.TimestampType()
        return result_type
class UnixToTimeStr(Function):  # pragma: no cover
    """
    Converts a Unix timestamp to a string using the specified format.
    """

    @staticmethod
    def infer_type(arg1: "Expression", arg2: "Expression") -> ct.StringType:
        """The result is always a string."""
        result_type = ct.StringType()
        return result_type
class Upper(Function):  # pragma: no cover
    """
    Converts a string value to uppercase.
    """

    @staticmethod
    def infer_type(arg: "Expression") -> ct.StringType:
        """The result is always a string."""
        result_type = ct.StringType()
        return result_type
class Variance(Function):  # pragma: no cover
    """
    Computes the sample variance of the input column or expression.
    """

    is_aggregation = True

    @staticmethod
    def infer_type(arg: "Expression") -> ct.DoubleType:
        """The result is always a double."""
        result_type = ct.DoubleType()
        return result_type
class VariancePop(Function):  # pragma: no cover
    """
    Computes the population variance of the input column or expression.
    """

    is_aggregation = True

    @staticmethod
    def infer_type(arg: "Expression") -> ct.DoubleType:
        """The result is always a double."""
        result_type = ct.DoubleType()
        return result_type
class Array(Function):  # pylint: disable=abstract-method
    """
    Returns an array of constants
    """


@Array.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    *elements: ct.ColumnType,
) -> ct.ListType:
    """
    All elements must share a single type, which becomes the list's element
    type; an empty array gets a ``NullType`` element type. Mixed element
    types raise ``DJParseException``.
    """
    distinct_types = {element.type for element in elements}
    if len(distinct_types) > 1:
        raise DJParseException(
            f"Multiple types {', '.join(sorted(str(typ) for typ in distinct_types))} passed to array.",
        )
    if elements:
        element_type = elements[0].type
    else:
        element_type = ct.NullType()
    return ct.ListType(element_type=element_type)
class Map(Function):  # pylint: disable=abstract-method
    """
    Returns a map of constants, built from alternating keys and values.
    """
def extract_consistent_type(elements):
    """
    Return the type shared by all elements, trying integer, double and float
    in that order, and falling back to string when none of them fits them all.
    """
    for candidate in (ct.IntegerType, ct.DoubleType, ct.FloatType):
        if all(isinstance(element.type, candidate) for element in elements):
            return candidate()
    return ct.StringType()
@Map.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    *elements: ct.ColumnType,
) -> ct.MapType:
    """
    Arguments alternate key, value, key, value, ...; an odd number of
    arguments raises ``DJParseException``. Key and value types are each
    resolved with ``extract_consistent_type``.
    """
    keys, values = elements[0::2], elements[1::2]
    if len(keys) != len(values):
        raise DJParseException("Different number of keys and values for MAP.")

    return ct.MapType(
        key_type=extract_consistent_type(keys),
        value_type=extract_consistent_type(values),
    )
class Week(Function):
    """
    Returns the week number of the year of the input date value.
    """

    @staticmethod
    def infer_type(arg: "Expression") -> ct.TinyIntType:
        """The result is always a tinyint."""
        result_type = ct.TinyIntType()
        return result_type
class Year(Function):
    """
    Returns the year of the input date value.
    """

    @staticmethod
    def infer_type(arg: "Expression") -> ct.IntegerType:
        """
        The result is always an integer.

        A tinyint (one signed byte, at most 127) cannot represent a calendar
        year, so the year is typed as an integer, matching Spark's ``year``.
        """
        return ct.IntegerType()
class FromJson(Function):  # pragma: no cover  # pylint: disable=abstract-method
    """
    Converts a JSON string to a struct or map.
    """


@FromJson.register  # type: ignore
def infer_type(  # noqa: F811  # pylint: disable=function-redefined  # pragma: no cover
    json: ct.StringType,
    schema: ct.StringType,
    options: Optional[Function] = None,
) -> ct.StructType:
    """
    Build a struct type from the fields obtained by parsing the value of
    ``schema`` with the ``complexColTypeList`` rule. ``options`` is ignored.
    """
    # TODO: Handle options?  # pylint: disable=fixme
    # pylint: disable=import-outside-toplevel
    from dj.sql.parsing.backends.antlr4 import parse_rule  # pragma: no cover

    fields = parse_rule(schema.value, "complexColTypeList")  # pragma: no cover
    return ct.StructType(*fields)  # pragma: no cover
class FunctionRegistryDict(dict):
    """
    Dictionary of functions that reports unknown names with a DJ-specific error.
    """

    def __getitem__(self, key):
        """
        Look up ``key``; a missing key raises ``DJNotImplementedException``
        (chained from the ``KeyError``) explaining how to request or add it.
        """
        try:
            found = super().__getitem__(key)
        except KeyError as exc:
            message = (
                f"The function `{key}` hasn't been implemented in "
                "DJ yet. You can file an issue at https://github."
                "com/DataJunction/dj/issues/new?title=Function+"
                f"missing:+{key} to request it to be added, or use "
                "the documentation at https://github.com/DataJunct"
                "ion/dj/blob/main/docs/functions.rst to implement it."
            )
            raise DJNotImplementedException(message) from exc
        return found
# https://spark.apache.org/docs/3.3.2/sql-ref-syntax-qry-select-tvf.html#content
class Explode(TableFunction):  # pylint: disable=abstract-method
    """
    Table function that explodes the specified array, nested array, or map
    column into multiple rows, generating a new row for each element in the
    specified column.
    """


@Explode.register
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    arg: ct.ListType,
) -> List[ct.NestedField]:
    """A list argument produces one column: the argument's ``element``."""
    columns = [arg.element]
    return columns


@Explode.register
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    arg: ct.MapType,
) -> List[ct.NestedField]:
    """A map argument produces two columns: its ``key`` and its ``value``."""
    columns = [arg.key, arg.value]
    return columns
class Unnest(TableFunction):  # pylint: disable=abstract-method
    """
    Table function that explodes the specified array, nested array, or map
    column into multiple rows, generating a new row for each element in the
    specified column.
    """


@Unnest.register
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    arg: ct.ListType,
) -> List[ct.NestedField]:
    """A list argument produces one column: the argument's ``element``."""
    return [arg.element]  # pragma: no cover


@Unnest.register
def infer_type(  # noqa: F811  # pylint: disable=function-redefined
    arg: ct.MapType,
) -> List[ct.NestedField]:
    """A map argument produces two columns: its ``key`` and its ``value``."""
    columns = [arg.key, arg.value]
    return columns
# Every direct subclass is registered under two upper-cased names: its class
# name as-is (``DateAdd`` -> ``DATEADD``) and its snake-cased form
# (``DateAdd`` -> ``DATE_ADD``). The regex inserts an underscore before each
# capital letter except one at the start of the name.
function_registry = FunctionRegistryDict()
for cls in Function.__subclasses__():
    snake_cased = re.sub(r"(?<!^)(?=[A-Z])", "_", cls.__name__)
    function_registry.update(
        {cls.__name__.upper(): cls, snake_cased.upper(): cls},
    )


# Table-valued functions get their own registry, built the same way.
table_function_registry = FunctionRegistryDict()
for cls in TableFunction.__subclasses__():
    snake_cased = re.sub(r"(?<!^)(?=[A-Z])", "_", cls.__name__)
    table_function_registry.update(
        {cls.__name__.upper(): cls, snake_cased.upper(): cls},
    )