Coverage for dj/construction/dj_query.py: 100%

52 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-17 20:05 -0700

1""" 

2Functions for making queries directly against DJ 

3""" 

4 

5from typing import List, Optional, Set, cast 

6 

7from sqlmodel import Session 

8 

9from dj.construction.build import build_ast 

10from dj.construction.utils import amenable_name, get_dj_node 

11from dj.errors import DJErrorException 

12from dj.models.node import NodeRevision, NodeType 

13from dj.sql.parsing.backends.antlr4 import ast, parse 

14from dj.sql.parsing.backends.exceptions import DJParseException 

15 

16 

def try_get_dj_node(
    session: Session,
    name: str,
    kinds: Set[NodeType],
) -> Optional[NodeRevision]:
    """
    Look up a DJ node through `get_dj_node`, but return None instead of
    raising when the lookup fails with a DJErrorException
    """
    try:
        found = get_dj_node(session, name, kinds)
    except DJErrorException:
        return None
    return found

27 

28 

def _resolve_metric_nodes(session, col):
    """
    Replace a column that names a metric node with that metric's query

    The column's identifier is looked up as a METRIC node; when nothing is
    found the column is left untouched. Otherwise the nearest enclosing
    SELECT is checked (once per SELECT, tracked through a `_validated`
    attribute) to make sure it sources from a single table named `metrics`,
    its FROM is emptied, and then the metric's own parsed SELECT is installed
    as the only relation. The column is replaced by a reference to the
    metric node's first column.

    Raises DJParseException when the enclosing SELECT does not source from a
    single table named `metrics`.
    """
    col_name = col.identifier(False)
    metric_node = try_get_dj_node(session, col_name, {NodeType.METRIC})
    if not metric_node:
        return

    # a metric was found, so inspect the SELECT it was referenced from
    parent_select = cast(ast.Select, col.get_nearest_parent_of_type(ast.Select))
    already_validated = getattr(parent_select, "_validated", False)
    if not already_validated:  # pragma: no cover
        relations = parent_select.from_.relations
        sources_only_metrics = (
            len(relations) == 1
            and relations[0].primary.alias_or_name.name == "metrics"
        )
        if not sources_only_metrics:
            raise DJParseException(
                "Any SELECT referencing a Metric must source "
                "from a single unaliased Table named `metrics`.",
            )
        metrics_ref = relations[0].primary
        try:
            metrics_ref_name = metrics_ref.alias_or_name.identifier(False)
        except AttributeError:  # pragma: no cover
            metrics_ref_name = ""
        if metrics_ref_name != "metrics":
            raise DJParseException(
                "The name of the table in a Metric query must be `metrics`.",
            )
        # empty the FROM so the real tables can be placed there below
        parent_select.from_ = ast.From([])
        parent_select._validated = True  # pylint: disable=W0212

    # from here on the metric is known to come from `metrics`
    metric_name = amenable_name(metric_node.name)
    metric_query = cast(str, metric_node.query)
    metric_select = parse(metric_query).select
    tables = metric_select.from_.find_all(ast.Table)
    metric_table_expression = ast.Alias(
        ast.Name(metric_name),
        None,
        metric_select,
    )

    # collect the joins that surface every table the metric selects from
    joins = []
    for table in tables:
        joins.extend(
            _hoist_metric_source_tables(
                session,
                table,
                metric_select,
                metric_table_expression,
            ),
        )

    first_column_name = ast.Name(metric_node.columns[0].name)
    metric_column = ast.Column(
        first_column_name,
        _table=metric_table_expression,
        as_=True,
    )

    metric_table_expression.child.parenthesized = True
    parent_select.replace(col, metric_column)
    metric_relation = ast.Relation(
        primary=metric_table_expression.child,
        extensions=joins,
    )
    parent_select.from_.relations = [metric_relation]

102 

103 

def _hoist_metric_source_tables(
    session,
    table,
    metric_select,
    metric_table_expression,
) -> List[ast.Join]:
    """
    Build the joins that hoist one table out of a metric query

    `table` comes from the metric's FROM. Selects (bare or aliased) yield no
    joins; an aliased table is unwrapped to its child. If the table's
    identifier resolves to a SOURCE, TRANSFORM or DIMENSION node, every
    column of that node is appended to the metric's projection under an
    alias, and a single LEFT OUTER join back to a copy of the table is
    returned, matching on all of those columns. This surfaces the node so
    that dimensions can be joined during build.

    Returns an empty list when there is nothing to hoist.
    """
    if isinstance(table, ast.Select):
        return []  # pragma: no cover
    if isinstance(table, ast.Alias):
        if isinstance(table.child, ast.Select):  # pragma: no cover
            return []  # pragma: no cover
        table = table.child  # pragma: no cover

    hoisted_joins = []
    table_node = try_get_dj_node(
        session,
        table.identifier(False),
        {NodeType.SOURCE, NodeType.TRANSFORM, NodeType.DIMENSION},
    )
    if table_node:  # pragma: no cover
        # expose every column of the node through the metric's projection
        # so the node can be left joined alongside the metric select
        source_cols = [
            _make_source_columns(node_col, table) for node_col in table_node.columns
        ]
        metric_select.projection += source_cols
        # one equality per hoisted column; together they form the ON clause
        ons = [
            _source_column_join_on_expression(hoisted, metric_table_expression)
            for hoisted in source_cols
        ]
        if ons:  # pragma: no cover
            join_criteria = ast.JoinCriteria(
                on=ast.BinaryOp.And(*ons),  # type: ignore # pylint: disable=no-value-for-parameter
            )
            hoisted_joins.append(
                ast.Join(
                    join_type="LEFT OUTER",
                    right=table.copy(),
                    criteria=join_criteria,
                ),
            )
    return hoisted_joins

154 

155 

def _make_source_columns(tbl_col, table) -> ast.Alias[ast.Column]:
    """
    Build the aliased column used to hoist `tbl_col` of `table`

    The alias name is `amenable_name` applied to the string form of the
    column; the alias wraps the column itself.
    """
    column = ast.Column(ast.Name(tbl_col.name), _table=table, as_=True)
    alias_name = ast.Name(amenable_name(str(column)))
    return ast.Alias(alias_name, child=column)

169 

170 

def _source_column_join_on_expression(
    src_col,
    metric_table_expression,
) -> ast.BinaryOp:
    """
    Make the part of the ON for the source column

    Builds one equality comparison: on the left, a column named after
    `src_col`'s alias (or name) and attached to `metric_table_expression`;
    on the right, a copy of the column that `src_col` wraps.

    A single expression is returned, not a list. The caller gathers one
    expression per hoisted column and combines them itself.
    """
    # the result is presumably an ast.BinaryOp; the ignore is kept because
    # the sibling `BinaryOp.And` call in this file needs one as well
    return ast.BinaryOp.Eq(  # type: ignore
        ast.Column(
            src_col.alias_or_name,
            _table=metric_table_expression,
        ),
        src_col.child.copy(),
    )

185 

186 

def build_dj_metric_query(  # pylint: disable=R0914,R0912
    session: Session,
    query: str,
    dialect: Optional[str] = None,  # pylint: disable=unused-argument
) -> ast.Query:
    """
    Build a dj query in SQL that may include dj metrics

    The query text is parsed, every column found in its SELECT is passed
    through `_resolve_metric_nodes` (which rewrites the SELECT in place
    for columns that name a metric), and the resulting SELECT is wrapped in
    a fresh `ast.Query` and handed to `build_ast`. `dialect` is accepted but
    not used here.
    """
    select = parse(query).select
    # visit each column; the ones naming a metric get swapped out
    for column in select.find_all(ast.Column):
        _resolve_metric_nodes(session, column)

    resolved_query = ast.Query(select=select)
    return build_ast(
        session,
        query=resolved_query,
        build_criteria=None,
    )