Coverage for dj/api/data.py: 100%

43 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-17 20:05 -0700

1""" 

2Data related APIs. 

3""" 

4 

5import logging 

6from typing import List, Optional 

7 

8from fastapi import APIRouter, Depends, Query 

9from fastapi.responses import JSONResponse 

10from sqlmodel import Session 

11 

12from dj.api.helpers import get_engine, get_node_by_name, get_query 

13from dj.errors import DJException, DJInvalidInputException 

14from dj.models.metric import TranslatedSQL 

15from dj.models.node import AvailabilityState, AvailabilityStateBase, NodeType 

16from dj.models.query import ColumnMetadata, QueryCreate, QueryWithResults 

17from dj.service_clients import QueryServiceClient 

18from dj.utils import get_query_service_client, get_session 

19 

# FastAPI router that the data-related endpoints below are registered on.
router = APIRouter()
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)

22 

23 

@router.post("/data/{node_name}/availability/")
def add_an_availability_state(
    node_name: str,
    data: AvailabilityStateBase,
    *,
    session: Session = Depends(get_session),
) -> JSONResponse:
    """
    Add an availability state to a node.

    The availability state must reference the same catalog as the node's current
    revision. For source nodes it must additionally reference the exact schema and
    table that the revision is defined on.

    If the revision already has an availability state for the same
    catalog/schema/table, the new state's ``max_partition`` becomes the larger of
    the two and its ``min_partition`` the smaller of the two. The merged state then
    replaces the revision's availability and is committed.

    Raises:
        DJException: if ``data.catalog`` differs from the node's catalog, or if the
            node is a source node and ``data``'s schema/table differ from its table.
    """
    node = get_node_by_name(session, node_name)
    node_revision = node.current

    # Every node type requires the availability state to be in the node's catalog.
    if data.catalog != node_revision.catalog.name:
        raise DJException(
            "Cannot set availability state in different catalog: "
            f"{data.catalog}, {node_revision.catalog.name}",
        )

    # Source nodes require that any availability states set are for the table the
    # source node is defined on.
    if node_revision.type == NodeType.SOURCE:
        if node_revision.schema_ != data.schema_ or node_revision.table != data.table:
            raise DJException(
                message=(
                    "Cannot set availability state, "
                    "source nodes require availability "
                    "states to match the set table: "
                    f"{data.catalog}."
                    f"{data.schema_}."
                    f"{data.table} "
                    "does not match "
                    f"{node_revision.catalog.name}."
                    f"{node_revision.schema_}."
                    f"{node_revision.table} "
                ),
            )

    # Merge the new availability state with the current availability state if one
    # exists for the same table.
    current = node_revision.availability
    if (
        current
        and current.catalog == node_revision.catalog.name
        and current.schema_ == data.schema_
        and current.table == data.table
    ):
        # Currently, we do not consider type information. We should eventually check
        # the type of the partition values in order to cast them before sorting.
        data.max_partition = max(current.max_partition, data.max_partition)
        data.min_partition = min(current.min_partition, data.min_partition)

    node_revision.availability = AvailabilityState.from_orm(data)
    session.add(node_revision)
    session.commit()
    return JSONResponse(
        status_code=200,
        content={"message": "Availability state successfully posted"},
    )

90 

91 

@router.get("/data/{node_name}/")
def get_data(  # pylint: disable=too-many-locals
    node_name: str,
    *,
    dimensions: List[str] = Query([]),
    filters: List[str] = Query([]),
    async_: bool = False,
    session: Session = Depends(get_session),
    query_service_client: QueryServiceClient = Depends(get_query_service_client),
    engine_name: Optional[str] = None,
    engine_version: Optional[str] = None,
) -> QueryWithResults:
    """
    Gets data for a node.

    Builds a query AST for ``node_name`` with ``get_query`` (passing ``dimensions``,
    ``filters`` and the selected engine) and submits its SQL to the query service.
    The query service's result is returned, with column metadata derived from the
    query's projection attached to the first result set, if there is one.

    If ``engine_name`` is not given, the first engine of the node's catalog is used.

    Raises:
        DJInvalidInputException: if no engine was requested and the node's catalog
            has no engines, or if the requested engine is not one of the catalog's
            engines.
    """
    node = get_node_by_name(session, node_name)
    catalog = node.current.catalog
    available_engines = catalog.engines

    if engine_name:
        engine = get_engine(session, engine_name, engine_version)  # type: ignore
    elif available_engines:
        engine = available_engines[0]
    else:
        # Without this guard, indexing an empty engine list raises IndexError.
        raise DJInvalidInputException(
            f"No engines are available for the node {node_name}.",
        )
    if engine not in available_engines:
        raise DJInvalidInputException(  # pragma: no cover
            f"The selected engine is not available for the node {node_name}. "
            "Available engines include: "
            f"{', '.join(available.name for available in available_engines)}",
        )

    query_ast = get_query(
        session=session,
        node_name=node_name,
        dimensions=dimensions,
        filters=filters,
        engine=engine,
    )
    columns = [
        ColumnMetadata(name=col.alias_or_name.name, type=str(col.type))  # type: ignore
        for col in query_ast.select.projection
    ]
    query = TranslatedSQL(
        sql=str(query_ast),
        columns=columns,
    )

    query_create = QueryCreate(
        engine_name=engine.name,
        catalog_name=catalog.name,
        engine_version=engine.version,
        submitted_query=query.sql,
        async_=async_,
    )
    result = query_service_client.submit_query(query_create)
    # Inject the column metadata computed from the query AST into the first result
    # set, if there are results.
    if result.results.__root__:  # pragma: no cover
        result.results.__root__[0].columns = columns
    return result