Coverage for dj/construction/utils.py: 100%

42 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-17 20:05 -0700

1""" 

2Utilities used around construction 

3""" 

4 

5from string import ascii_letters, digits 

6from typing import TYPE_CHECKING, List, Optional, Set 

7 

8from sqlalchemy.orm.exc import NoResultFound 

9from sqlmodel import Session, select 

10 

11from dj.errors import DJError, DJErrorException, ErrorCode 

12from dj.models.node import Node, NodeRevision, NodeType 

13 

14if TYPE_CHECKING: 

15 from dj.sql.parsing.ast import Name 

16 

17 

def get_dj_node(
    session: Session,
    node_name: str,
    kinds: Optional[Set[NodeType]] = None,
) -> NodeRevision:
    """
    Return the current revision of the DJ node named ``node_name``.

    Args:
        session: session used to execute the lookup query.
        node_name: exact name of the node to look up.
        kinds: when given and non-empty, only nodes whose type is in this
            set are considered matches.

    Returns:
        The ``current`` revision of the matching node.

    Raises:
        DJErrorException: with ``ErrorCode.UNKNOWN_NODE`` when no node with
            that name (and, if ``kinds`` is given, of one of those types)
            exists.
    """
    query = select(Node).filter(Node.name == node_name)
    if kinds:
        query = query.filter(Node.type.in_(kinds))  # type: ignore # pylint: disable=no-member
    try:
        # .one() either returns exactly one row or raises, so ``match`` is
        # always a real Node on the success path.
        match = session.exec(query).one()
    except NoResultFound as no_result_exc:
        # Sort the kinds: iterating a set has no guaranteed order, which
        # would otherwise make this error message vary between runs.
        kind_msg = " or ".join(sorted(str(k) for k in kinds)) if kinds else ""
        raise DJErrorException(
            DJError(
                code=ErrorCode.UNKNOWN_NODE,
                message=f"No node `{node_name}` exists of kind {kind_msg}.",
            ),
        ) from no_result_exc
    return match.current

39 

40 

# Characters that may appear verbatim in an "amenable" name:
# ASCII letters, decimal digits and the underscore.
ACCEPTABLE_CHARS = set(ascii_letters) | set(digits) | {"_"}

# Spelled-out replacement tokens for punctuation characters. Characters that
# are neither acceptable nor listed here are handled by the caller (see
# ``amenable_name``, which falls back to "UNK"). Entry order is kept stable.
LOOKUP_CHARS = dict(
    [
        (".", "DOT"),
        ("'", "QUOTE"),
        ('"', "DQUOTE"),
        ("`", "BTICK"),
        ("!", "EXCL"),
        ("@", "AT"),
        ("#", "HASH"),
        ("$", "DOLLAR"),
        ("%", "PERC"),
        ("^", "CARAT"),
        ("&", "AMP"),
        ("*", "STAR"),
        ("(", "LPAREN"),
        (")", "RPAREN"),
        ("[", "LBRACK"),
        ("]", "RBRACK"),
        ("-", "MINUS"),
        ("+", "PLUS"),
        ("=", "EQ"),
        ("/", "FSLSH"),
        ("\\", "BSLSH"),
        ("|", "PIPE"),
        ("~", "TILDE"),
    ],
)

67 

68 

def amenable_name(name: str) -> str:
    """
    Rewrite ``name`` so that it contains only letters, digits and underscores.

    Every character outside ACCEPTABLE_CHARS is replaced by its token from
    LOOKUP_CHARS, or "UNK" if it has none. The token is joined to the
    neighbouring runs of acceptable characters with "_". Leading and trailing
    underscores are stripped from the final result.
    """
    pieces: List[str] = []
    run_start = 0
    for position, char in enumerate(name):
        if char in ACCEPTABLE_CHARS:
            continue
        # Close the run of acceptable characters (possibly empty) that ends
        # just before this character, then add the replacement token.
        pieces.append(name[run_start:position])
        pieces.append(LOOKUP_CHARS.get(char, "UNK"))
        run_start = position + 1
    # The trailing run after the last replaced character (or the whole name).
    pieces.append(name[run_start:])
    return "_".join(pieces).strip("_")

82 

83 

def to_namespaced_name(name: str) -> "Name":
    """
    Build a namespaced ``Name`` from a dotted string.

    ``"a.b.c"`` becomes ``Name("c")``. Its ``namespace`` is ``Name("b")``,
    whose ``namespace`` is in turn ``Name("a")``. The returned object is the
    right-most part of the dotted string. A string without dots yields a
    single ``Name``, and the empty string yields ``Name("")``.
    """
    from dj.sql.parsing.ast import Name  # pylint: disable=import-outside-toplevel

    chunks = name.split(".")
    # str.split always returns at least one element, so the right-most chunk
    # exists even for an empty input. Use it directly instead of relying on
    # the truthiness of a Name object to detect the first iteration.
    full_name = Name(chunks[-1])
    current_name = full_name
    for chunk in reversed(chunks[:-1]):
        # Attach each preceding chunk as the namespace of the chunk to its right.
        parent = Name(chunk)
        current_name.namespace = parent
        current_name = parent
    return full_name