sqlglot.dialects.dialect
from __future__ import annotations

import typing as t
from enum import Enum
from functools import reduce

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

B = t.TypeVar("B", bound=exp.Binary)


class Dialects(str, Enum):
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    Doris = "doris"


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}

        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

        dialect_properties = {
            **{
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            },
            "TOKENIZER_CLASS": klass.tokenizer_class,
        }

        if enum not in ("", "bigquery"):
            dialect_properties["SELECT_KINDS"] = ()

        # Pass required dialect properties to the tokenizer, parser and generator classes
        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
            for name, value in dialect_properties.items():
                if hasattr(subclass, name):
                    setattr(subclass, name, value)

        if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        klass.generator_class.can_identify = klass.can_identify

        return klass


class Dialect(metaclass=_Dialect):
    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines whether or not user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Determines whether or not SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Determines whether the base comes first in the LOG function
    LOG_BASE_FIRST = True

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Mapping of an unescaped escape sequence to the corresponding character
    ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    def __eq__(self, other: t.Any) -> bool:
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized to lowercase regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
290 """ 291 if identify is True or identify == "always": 292 return True 293 294 if identify == "safe": 295 return not cls.case_sensitive(text) 296 297 return False 298 299 @classmethod 300 def quote_identifier(cls, expression: E, identify: bool = True) -> E: 301 if isinstance(expression, exp.Identifier): 302 name = expression.this 303 expression.set( 304 "quoted", 305 identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 306 ) 307 308 return expression 309 310 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 311 return self.parser(**opts).parse(self.tokenize(sql), sql) 312 313 def parse_into( 314 self, expression_type: exp.IntoType, sql: str, **opts 315 ) -> t.List[t.Optional[exp.Expression]]: 316 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 317 318 def generate(self, expression: t.Optional[exp.Expression], **opts) -> str: 319 return self.generator(**opts).generate(expression) 320 321 def transpile(self, sql: str, **opts) -> t.List[str]: 322 return [self.generate(expression, **opts) for expression in self.parse(sql)] 323 324 def tokenize(self, sql: str) -> t.List[Token]: 325 return self.tokenizer.tokenize(sql) 326 327 @property 328 def tokenizer(self) -> Tokenizer: 329 if not hasattr(self, "_tokenizer"): 330 self._tokenizer = self.tokenizer_class() 331 return self._tokenizer 332 333 def parser(self, **opts) -> Parser: 334 return self.parser_class(**opts) 335 336 def generator(self, **opts) -> Generator: 337 return self.generator_class(**opts) 338 339 340DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 341 342 343def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 344 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 345 346 347def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 348 if expression.args.get("accuracy"): 349 self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy") 350 return self.func("APPROX_COUNT_DISTINCT", expression.this) 351 352 353def if_sql( 354 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 355) -> t.Callable[[Generator, exp.If], str]: 356 def _if_sql(self: Generator, expression: exp.If) -> str: 357 return self.func( 358 name, 359 expression.this, 360 expression.args.get("true"), 361 expression.args.get("false") or false_value, 362 ) 363 364 return _if_sql 365 366 367def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str: 368 return self.binary(expression, "->") 369 370 371def arrow_json_extract_scalar_sql( 372 self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar 373) -> str: 374 return self.binary(expression, "->>") 375 376 377def inline_array_sql(self: Generator, expression: exp.Array) -> str: 378 return f"[{self.expressions(expression, flat=True)}]" 379 380 381def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 382 return self.like_sql( 383 exp.Like( 384 this=exp.Lower(this=expression.this.copy()), expression=expression.expression.copy() 385 ) 386 ) 387 388 389def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 390 zone = self.sql(expression, "this") 391 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 392 393 394def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 395 if expression.args.get("recursive"): 396 self.unsupported("Recursive CTEs are unsupported") 397 expression.args["recursive"] = 
False 398 return self.with_sql(expression) 399 400 401def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: 402 n = self.sql(expression, "this") 403 d = self.sql(expression, "expression") 404 return f"IF({d} <> 0, {n} / {d}, NULL)" 405 406 407def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: 408 self.unsupported("TABLESAMPLE unsupported") 409 return self.sql(expression.this) 410 411 412def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 413 self.unsupported("PIVOT unsupported") 414 return "" 415 416 417def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: 418 return self.cast_sql(expression) 419 420 421def no_properties_sql(self: Generator, expression: exp.Properties) -> str: 422 self.unsupported("Properties unsupported") 423 return "" 424 425 426def no_comment_column_constraint_sql( 427 self: Generator, expression: exp.CommentColumnConstraint 428) -> str: 429 self.unsupported("CommentColumnConstraint unsupported") 430 return "" 431 432 433def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: 434 self.unsupported("MAP_FROM_ENTRIES unsupported") 435 return "" 436 437 438def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: 439 this = self.sql(expression, "this") 440 substr = self.sql(expression, "substr") 441 position = self.sql(expression, "position") 442 if position: 443 return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1" 444 return f"STRPOS({this}, {substr})" 445 446 447def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 448 return ( 449 f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" 450 ) 451 452 453def var_map_sql( 454 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 455) -> str: 456 keys = expression.args["keys"] 457 values = expression.args["values"] 458 459 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 460 self.unsupported("Cannot convert array columns into map.") 461 return self.func(map_func_name, keys, values) 462 463 args = [] 464 for key, value in zip(keys.expressions, values.expressions): 465 args.append(self.sql(key)) 466 args.append(self.sql(value)) 467 468 return self.func(map_func_name, *args) 469 470 471def format_time_lambda( 472 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 473) -> t.Callable[[t.List], E]: 474 """Helper used for time expressions. 475 476 Args: 477 exp_class: the expression class to instantiate. 478 dialect: target sql dialect. 479 default: the default format, True being time. 480 481 Returns: 482 A callable that can be used to return the appropriately formatted time expression. 483 """ 484 485 def _format_time(args: t.List): 486 return exp_class( 487 this=seq_get(args, 0), 488 format=Dialect[dialect].format_time( 489 seq_get(args, 1) 490 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 491 ), 492 ) 493 494 return _format_time 495 496 497def time_format( 498 dialect: DialectType = None, 499) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 500 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 501 """ 502 Returns the time format for a given expression, unless it's equivalent 503 to the default time format of the dialect of interest. 
504 """ 505 time_format = self.format_time(expression) 506 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 507 508 return _time_format 509 510 511def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: 512 """ 513 In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the 514 PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding 515 columns are removed from the create statement. 516 """ 517 has_schema = isinstance(expression.this, exp.Schema) 518 is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW") 519 520 if has_schema and is_partitionable: 521 expression = expression.copy() 522 prop = expression.find(exp.PartitionedByProperty) 523 if prop and prop.this and not isinstance(prop.this, exp.Schema): 524 schema = expression.this 525 columns = {v.name.upper() for v in prop.this.expressions} 526 partitions = [col for col in schema.expressions if col.name.upper() in columns] 527 schema.set("expressions", [e for e in schema.expressions if e not in partitions]) 528 prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) 529 expression.set("this", schema) 530 531 return self.create_sql(expression) 532 533 534def parse_date_delta( 535 exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None 536) -> t.Callable[[t.List], E]: 537 def inner_func(args: t.List) -> E: 538 unit_based = len(args) == 3 539 this = args[2] if unit_based else seq_get(args, 0) 540 unit = args[0] if unit_based else exp.Literal.string("DAY") 541 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 542 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 543 544 return inner_func 545 546 547def parse_date_delta_with_interval( 548 expression_class: t.Type[E], 549) -> t.Callable[[t.List], t.Optional[E]]: 550 def func(args: t.List) -> t.Optional[E]: 551 if len(args) < 2: 552 return None 553 554 interval = args[1] 555 556 if not isinstance(interval, exp.Interval): 557 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 558 559 expression = interval.this 560 if expression and expression.is_string: 561 expression = exp.Literal.number(expression.this) 562 563 return expression_class( 564 this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit")) 565 ) 566 567 return func 568 569 570def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 571 unit = seq_get(args, 0) 572 this = seq_get(args, 1) 573 574 if isinstance(this, exp.Cast) and this.is_type("date"): 575 return exp.DateTrunc(unit=unit, this=this) 576 return exp.TimestampTrunc(this=this, unit=unit) 577 578 579def date_add_interval_sql( 580 data_type: str, kind: str 581) -> t.Callable[[Generator, exp.Expression], str]: 582 def func(self: Generator, expression: exp.Expression) -> str: 583 this = self.sql(expression, "this") 584 unit = expression.args.get("unit") 585 unit = exp.var(unit.name.upper() if unit else "DAY") 586 interval = exp.Interval(this=expression.expression.copy(), unit=unit) 587 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 588 589 return func 590 591 592def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 593 return self.func( 594 "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this 595 ) 596 597 598def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 599 if not 
expression.expression: 600 return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP)) 601 if expression.text("expression").lower() in TIMEZONES: 602 return self.sql( 603 exp.AtTimeZone( 604 this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP), 605 zone=expression.expression, 606 ) 607 ) 608 return self.function_fallback_sql(expression) 609 610 611def locate_to_strposition(args: t.List) -> exp.Expression: 612 return exp.StrPosition( 613 this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) 614 ) 615 616 617def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str: 618 return self.func( 619 "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position") 620 ) 621 622 623def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: 624 expression = expression.copy() 625 return self.sql( 626 exp.Substring( 627 this=expression.this, start=exp.Literal.number(1), length=expression.expression 628 ) 629 ) 630 631 632def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: 633 expression = expression.copy() 634 return self.sql( 635 exp.Substring( 636 this=expression.this, 637 start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1), 638 ) 639 ) 640 641 642def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str: 643 return self.sql(exp.cast(expression.this, "timestamp")) 644 645 646def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: 647 return self.sql(exp.cast(expression.this, "date")) 648 649 650# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8 651def encode_decode_sql( 652 self: Generator, expression: exp.Expression, name: str, replace: bool = True 653) -> str: 654 charset = expression.args.get("charset") 655 if charset and charset.name.lower() != "utf-8": 656 self.unsupported(f"Expected utf-8 character set, got {charset}.") 657 658 return self.func(name, expression.this, expression.args.get("replace") if replace else None) 659 660 661def min_or_least(self: Generator, expression: exp.Min) -> str: 662 name = "LEAST" if expression.expressions else "MIN" 663 return rename_func(name)(self, expression) 664 665 666def max_or_greatest(self: Generator, expression: exp.Max) -> str: 667 name = "GREATEST" if expression.expressions else "MAX" 668 return rename_func(name)(self, expression) 669 670 671def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 672 cond = expression.this 673 674 if isinstance(expression.this, exp.Distinct): 675 cond = expression.this.expressions[0] 676 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 677 678 return self.func("sum", exp.func("if", cond.copy(), 1, 0)) 679 680 681def trim_sql(self: Generator, expression: exp.Trim) -> str: 682 target = self.sql(expression, "this") 683 trim_type = self.sql(expression, "position") 684 remove_chars = self.sql(expression, "expression") 685 collation = self.sql(expression, "collation") 686 687 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 688 if not remove_chars and not collation: 689 return self.trim_sql(expression) 690 691 trim_type = f"{trim_type} " if trim_type else "" 692 remove_chars = f"{remove_chars} " if remove_chars else "" 693 from_part = "FROM " if trim_type or remove_chars else "" 694 collation = f" COLLATE {collation}" if collation else "" 695 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" 696 697 
def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return self.sql(
                exp.cast(
                    exp.StrToTime(this=expression.this, format=expression.args["format"]),
                    "date",
                )
            )
        return self.sql(exp.cast(expression.this, "date"))

    return _ts_or_ds_to_date_sql


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    expression = expression.copy()
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    expression = expression.copy()
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
765 """ 766 agg_all_unquoted = agg.transform( 767 lambda node: exp.Identifier(this=node.name, quoted=False) 768 if isinstance(node, exp.Identifier) 769 else node 770 ) 771 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 772 773 return names 774 775 776def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: 777 return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) 778 779 780# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects 781def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: 782 return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) 783 784 785def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: 786 return self.func("MAX", expression.this) 787 788 789def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: 790 a = self.sql(expression.left) 791 b = self.sql(expression.right) 792 return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" 793 794 795# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon 796def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str: 797 return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}" 798 799 800def is_parse_json(expression: exp.Expression) -> bool: 801 return isinstance(expression, exp.ParseJSON) or ( 802 isinstance(expression, exp.Cast) and expression.is_type("json") 803 ) 804 805 806def isnull_to_is_null(args: t.List) -> exp.Expression: 807 return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) 808 809 810def move_insert_cte_sql(self: Generator, expression: exp.Insert) -> str: 811 if expression.expression.args.get("with"): 812 expression = expression.copy() 813 expression.set("with", expression.expression.args["with"].pop()) 814 return self.insert_sql(expression) 815 816 817def generatedasidentitycolumnconstraint_sql( 818 self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint 819) -> str: 820 start = self.sql(expression, "start") or "1" 821 increment = self.sql(expression, "increment") or "1" 822 return f"IDENTITY({start}, {increment})"
class Dialects(str, Enum):
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    Doris = "doris"
An enumeration of the SQL dialects supported by sqlglot.
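Because Dialects mixes in str, each member's value is the plain string key under which the corresponding Dialect subclass is registered, so strings and enum values are interchangeable at the API boundary. A minimal sketch:

import sqlglot
from sqlglot.dialects.dialect import Dialects

# Member values double as the registry keys used by the top-level API.
assert Dialects.DUCKDB.value == "duckdb"
print(sqlglot.transpile("SELECT 1", read=Dialects.MYSQL.value, write=Dialects.DUCKDB.value))
# ['SELECT 1']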
class Dialect(metaclass=_Dialect):
    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines whether or not user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Determines whether or not SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Determines whether the base comes first in the LOG function
    LOG_BASE_FIRST = True

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Mapping of an unescaped escape sequence to the corresponding character
    ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    def __eq__(self, other: t.Any) -> bool:
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized to lowercase regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(**opts)
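A Dialect subclass mostly wires its Tokenizer, Parser and Generator together; the instance methods above delegate to them. A small usage sketch against the registered MySQL dialect:

from sqlglot.dialects.dialect import Dialect

dialect = Dialect.get_or_raise("mysql")()  # look up the registered class, then instantiate
expressions = dialect.parse("SELECT 1")    # tokenize + parse into expression trees
print(dialect.generate(expressions[0]))    # SELECT 1
print(dialect.transpile("SELECT 1"))       # ['SELECT 1']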
    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result
    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression
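format_time translates a dialect-specific time format string into the python-style format sqlglot uses internally, via TIME_MAPPING; note that a raw string argument is assumed to still carry its surrounding quotes. A sketch against Hive, whose TIME_MAPPING includes entries such as "yyyy" -> "%Y":

from sqlglot.dialects.dialect import Dialect

hive = Dialect.get_or_raise("hive")
literal = hive.format_time("'yyyy-MM-dd'")  # the outer quotes are stripped before mapping
print(literal.this)  # %Y-%m-%d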
    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized to lowercase regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression
Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized to lowercase regardless of being quoted or not.
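For example, Snowflake resolves unquoted identifiers as uppercase (RESOLVES_IDENTIFIERS_AS_UPPERCASE = True), so a sketch of both branches:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

snowflake = Dialect.get_or_raise("snowflake")
print(snowflake.normalize_identifier(exp.to_identifier("Foo")).this)               # FOO
print(snowflake.normalize_identifier(exp.to_identifier("Foo", quoted=True)).this)  # Foo (quoted, so untouched)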
    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
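On the base Dialect, unquoted identifiers resolve to lowercase, so any uppercase character makes the text case sensitive:

from sqlglot.dialects.dialect import Dialect

print(Dialect.case_sensitive("foo"))  # False
print(Dialect.case_sensitive("Foo"))  # True: 'F' would not survive lowercasing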
    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False
Checks if text can be identified given an identify option.

Arguments:
- text: The text to check.
- identify: "always" or `True`: Always returns true. "safe": True if the identifier is case-insensitive.

Returns:
Whether or not the given text can be identified.
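Continuing on the base Dialect:

from sqlglot.dialects.dialect import Dialect

print(Dialect.can_identify("foo", "safe"))    # True: quoting can't change its meaning
print(Dialect.can_identify("Foo", "safe"))    # False: case sensitive, see case_sensitive()
print(Dialect.can_identify("Foo", "always"))  # True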
    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression
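quote_identifier marks an Identifier as quoted when quoting is requested explicitly, or forced because the name is case sensitive or falls outside SAFE_IDENTIFIER_RE. A sketch on the base Dialect:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

print(Dialect.quote_identifier(exp.to_identifier("foo"), identify=False).sql())      # foo
print(Dialect.quote_identifier(exp.to_identifier("foo bar"), identify=False).sql())  # "foo bar"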
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format; if True, the dialect's default TIME_FORMAT is used.
Returns:
A callable that can be used to return the appropriately formatted time expression.
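A sketch of the resulting parser callable, assuming Hive's TIME_MAPPING; the argument list mirrors what the parser collects for a function call:

from sqlglot import exp
from sqlglot.dialects.dialect import format_time_lambda

parse_str_to_time = format_time_lambda(exp.StrToTime, "hive")
node = parse_str_to_time([exp.column("x"), exp.Literal.string("yyyy-MM-dd")])
print(node.args["format"].this)  # %Y-%m-%d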
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
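For instance, transpiling a Spark-style statement whose partition column also appears in the schema should move that column into the PARTITIONED BY schema. A hedged sketch through the public API (the exact output is indicative and assumes Hive wires this helper into its generator):

import sqlglot

sql = "CREATE TABLE t (a INT, b INT) PARTITIONED BY (b)"
print(sqlglot.transpile(sql, read="spark", write="hive")[0])
# e.g. CREATE TABLE t (a INT) PARTITIONED BY (b INT)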
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func
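MySQL, for example, wires this up for DATE_ADD, whose second argument is an INTERVAL expression; a sketch of the resulting tree:

import sqlglot

node = sqlglot.parse_one("DATE_ADD(x, INTERVAL 1 DAY)", read="mysql")
print(node.this.sql(), node.expression.sql(), node.text("unit"))  # x 1 DAY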
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression.copy(), unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.function_fallback_sql(expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond.copy(), 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
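Helpers like trim_sql are meant to be plugged into a dialect generator's TRANSFORMS mapping. A minimal sketch of a hypothetical custom dialect (the class name is illustrative; the metaclass auto-registers it as "mydialect"):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, trim_sql
from sqlglot.generator import Generator

class MyDialect(Dialect):  # hypothetical dialect for illustration
    class Generator(Generator):
        TRANSFORMS = {**Generator.TRANSFORMS, exp.Trim: trim_sql}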
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return self.sql(
                exp.cast(
                    exp.StrToTime(this=expression.this, format=expression.args["format"]),
                    "date",
                )
            )
        return self.sql(exp.cast(expression.this, "date"))

    return _ts_or_ds_to_date_sql
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    expression = expression.copy()
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
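A sketch of the two naming paths, one aliased aggregation and one without (output is indicative):

from sqlglot import parse_one
from sqlglot.dialects.dialect import pivot_column_names

aggs = [parse_one("AVG(foo) AS f"), parse_one("MAX(foo)")]
print(pivot_column_names(aggs, dialect="duckdb"))  # e.g. ['f', 'max(foo)']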