Coverage for src/esxport.py: 98%

132 statements  

coverage.py v7.3.1, created at 2023-09-22 17:59 +0530

1"""Main export module.""" 

2from __future__ import annotations 

3 

4import contextlib 

5import json 

6from pathlib import Path 

7from typing import TYPE_CHECKING, Any 

8 

9from elasticsearch.exceptions import ConnectionError 

10from loguru import logger 

11from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential 

12from tqdm import tqdm 

13 

14from src.constant import FLUSH_BUFFER, TIMES_TO_TRY 

15from src.exceptions import FieldNotFoundError, IndexNotFoundError, MetaFieldNotFoundError, ScrollExpiredError 

16from src.strings import index_not_found, meta_field_not_found, output_fields, sorting_by, using_indexes, using_query 

17from src.writer import Writer 

18 

19if TYPE_CHECKING: 

20 from typing_extensions import Self 

21 

22 from src.click_opt.cli_options import CliOptions 

23 from src.elastic import ElasticsearchClient 

24 

25 

26class EsXport(object): 

27 """Main class.""" 

28 

    def __init__(self: Self, opts: CliOptions, es_client: ElasticsearchClient) -> None:
        self.search_args: dict[str, Any] = {}
        self.opts = opts
        self.num_results = 0
        self.scroll_ids: list[str] = []
        self.scroll_time = "30m"
        self.rows_written = 0

        self.es_client = es_client

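    # Retries on ConnectionError with exponential backoff, up to TIMES_TO_TRY attempts.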
    @retry(
        wait=wait_exponential(2),
        stop=stop_after_attempt(TIMES_TO_TRY),
        reraise=True,
        retry=retry_if_exception_type(ConnectionError),
    )
    def _check_indexes(self: Self) -> None:
        """Check if input indexes exist."""
        indexes = self.opts.index_prefixes
        if "_all" in indexes:
            indexes = ["_all"]
        else:
            indexes_status = self.es_client.indices_exists(index=indexes)
            if not indexes_status:
                msg = index_not_found.format(", ".join(self.opts.index_prefixes), self.opts.url)
                raise IndexNotFoundError(msg)
        self.opts.index_prefixes = indexes

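    # Fails fast if a requested field or sort key is missing from every index mapping.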
    def _validate_fields(self: Self) -> None:
        all_fields_dict: dict[str, list[str]] = {}
        indices_names = list(self.opts.index_prefixes)
        all_expected_fields = self.opts.fields.copy()
        for sort_query in self.opts.sort:
            sort_key = next(iter(sort_query.keys()))
            # A dotted sort key like "age.keyword" maps to the top-level field "age".
            sort_param = sort_key.split(".")[0]
            all_expected_fields.append(sort_param)
        if "_all" in all_expected_fields:
            all_expected_fields.remove("_all")

        for index in indices_names:
            response: dict[str, Any] = self.es_client.get_mapping(index=index)
            all_fields_dict[index] = []
            for field in response[index]["mappings"]["properties"]:
                all_fields_dict[index].append(field)
        all_es_fields = {value for values_list in all_fields_dict.values() for value in values_list}

        for element in all_expected_fields:
            if element not in all_es_fields:
                msg = f"Field {element} doesn't exist in any index."
                raise FieldNotFoundError(msg)

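    # Builds the scroll-search arguments from the CLI options.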
    def _prepare_search_query(self: Self) -> None:
        """Prepare the search query from input."""
        self.search_args = {
            "index": ",".join(self.opts.index_prefixes),
            "scroll": self.scroll_time,
            "size": self.opts.scroll_size,
            "terminate_after": self.opts.max_results,
            "body": self.opts.query,
        }
        if self.opts.sort:
            self.search_args["sort"] = self.opts.sort

        if "_all" not in self.opts.fields:
            self.search_args["_source_includes"] = ",".join(self.opts.fields)

        if self.opts.debug:
            logger.debug(using_indexes.format(indexes=", ".join(self.opts.index_prefixes)))
            logger.debug(using_query.format(query=json.dumps(self.opts.query)))
            logger.debug(output_fields.format(fields=", ".join(self.opts.fields)))
            logger.debug(sorting_by.format(sort=self.opts.sort))

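    # Each scroll call must arrive within scroll_time, or the server-side context expires.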
    @retry(
        wait=wait_exponential(2),
        stop=stop_after_attempt(TIMES_TO_TRY),
        reraise=True,
        retry=retry_if_exception_type(ConnectionError),
    )
    def next_scroll(self: Self, scroll_id: str) -> Any:
        """Paginate to the next page."""
        return self.es_client.scroll(scroll=self.scroll_time, scroll_id=scroll_id)

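    # Pages through the scroll results, flushing every FLUSH_BUFFER hits to keep memory bounded.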
    def _write_to_temp_file(self: Self, res: Any) -> None:
        """Write the search results to a temporary file."""
        hit_list = []
        total_size = int(min(self.opts.max_results, self.num_results))
        bar = tqdm(
            desc=f"{self.opts.output_file}.tmp",
            total=total_size,
            unit="docs",
            colour="green",
        )
        try:
            while self.rows_written != total_size:
                if res["_scroll_id"] not in self.scroll_ids:
                    self.scroll_ids.append(res["_scroll_id"])

                for hit in res["hits"]["hits"]:
                    self.rows_written += 1
                    bar.update(1)
                    hit_list.append(hit)
                    if len(hit_list) == FLUSH_BUFFER:
                        self._flush_to_file(hit_list)
                        hit_list = []
                res = self.next_scroll(res["_scroll_id"])
        except ScrollExpiredError:
            logger.error("Scroll expired (multiple reads?). Saving loaded data.")
        finally:
            bar.close()
            self._flush_to_file(hit_list)

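    # Runs the initial search; results, if any, are streamed to the temp file.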
    @retry(
        wait=wait_exponential(2),
        stop=stop_after_attempt(TIMES_TO_TRY),
        reraise=True,
        retry=retry_if_exception_type(ConnectionError),
    )
    def search_query(self: Self) -> Any:
        """Search the index."""
        self._validate_fields()
        self._prepare_search_query()
        res = self.es_client.search(**self.search_args)
        self.num_results = res["hits"]["total"]["value"]

        logger.info(f"Found {self.num_results} results.")

        if self.num_results > 0:
            self._write_to_temp_file(res)

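    # Writes each hit's _source as one JSON object per line, with optional metadata fields.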
    def _flush_to_file(self: Self, hit_list: list[dict[str, Any]]) -> None:
        """Flush the search results to a temporary file."""

        def add_meta_fields() -> None:
            if self.opts.meta_fields:
                for field in self.opts.meta_fields:
                    try:
                        data[field] = hit[field]
                    except KeyError as e:  # noqa: PERF203
                        raise MetaFieldNotFoundError(meta_field_not_found.format(field=field)) from e

        with Path(f"{self.opts.output_file}.tmp").open(mode="a", encoding="utf-8") as tmp_file:
            for hit in hit_list:
                data = hit["_source"]
                data.pop("_meta", None)
                add_meta_fields()
                tmp_file.write(json.dumps(data))
                tmp_file.write("\n")

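    # Best-effort cleanup of server-side scroll contexts; any failure here is suppressed.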
    def _clean_scroll_ids(self: Self) -> None:
        """Clear all scroll ids."""
        with contextlib.suppress(Exception):
            self.es_client.clear_scroll(scroll_id="_all")

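    # The temp file is JSON Lines, so the first record's keys serve as the output headers.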
    def _extract_headers(self: Self) -> list[str]:
        """Extract CSV headers from the first line of the file."""
        with Path(f"{self.opts.output_file}.tmp").open() as f:
            first_line = json.loads(f.readline().strip("\n"))
            return list(first_line.keys())

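    # Delegates final rendering to Writer using the CLI-selected delimiter and format.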
    def _export(self: Self) -> None:
        """Write the collected results to the final output file."""
        headers = self._extract_headers()
        kwargs = {
            "delimiter": self.opts.delimiter,
            "output_format": self.opts.format,
        }
        Writer.write(
            headers=headers,
            total_records=self.rows_written,
            out_file=self.opts.output_file,
            **kwargs,
        )

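    # Public entry point: validate indexes, run the search, clean up, then write the output.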
    def export(self: Self) -> None:
        """Export the data."""
        self._check_indexes()
        self.search_query()
        self._clean_scroll_ids()
        self._export()