Coverage for src/esxport.py: 98% (132 statements)
coverage.py v7.3.1, created at 2023-09-22 17:59 +0530
1"""Main export module."""
2from __future__ import annotations
4import contextlib
5import json
6from pathlib import Path
7from typing import TYPE_CHECKING, Any
9from elasticsearch.exceptions import ConnectionError
10from loguru import logger
11from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
12from tqdm import tqdm
14from src.constant import FLUSH_BUFFER, TIMES_TO_TRY
15from src.exceptions import FieldNotFoundError, IndexNotFoundError, MetaFieldNotFoundError, ScrollExpiredError
16from src.strings import index_not_found, meta_field_not_found, output_fields, sorting_by, using_indexes, using_query
17from src.writer import Writer
19if TYPE_CHECKING:
20 from typing_extensions import Self
22 from src.click_opt.cli_options import CliOptions
23 from src.elastic import ElasticsearchClient

class EsXport:
    """Main class."""

    def __init__(self: Self, opts: CliOptions, es_client: ElasticsearchClient) -> None:
        self.search_args: dict[str, Any] = {}
        self.opts = opts
        self.num_results = 0
        self.scroll_ids: list[str] = []
        self.scroll_time = "30m"
        self.rows_written = 0

        self.es_client = es_client

    @retry(
        wait=wait_exponential(2),
        stop=stop_after_attempt(TIMES_TO_TRY),
        reraise=True,
        retry=retry_if_exception_type(ConnectionError),
    )
    def _check_indexes(self: Self) -> None:
        """Check if input indexes exist."""
        indexes = self.opts.index_prefixes
        if "_all" in indexes:
            indexes = ["_all"]
        else:
            indexes_status = self.es_client.indices_exists(index=indexes)
            if not indexes_status:
                msg = index_not_found.format(", ".join(self.opts.index_prefixes), self.opts.url)
                raise IndexNotFoundError(
                    msg,
                )
        self.opts.index_prefixes = indexes

    def _validate_fields(self: Self) -> None:
        """Validate that all requested and sort fields exist in the target indexes."""
        all_fields_dict: dict[str, list[str]] = {}
        indices_names = list(self.opts.index_prefixes)
        all_expected_fields = self.opts.fields.copy()
        for sort_query in self.opts.sort:
            sort_key = next(iter(sort_query.keys()))
            parts = sort_key.split(".")
            sort_param = parts[0] if len(parts) > 0 else sort_key
            all_expected_fields.append(sort_param)
        if "_all" in all_expected_fields:
            all_expected_fields.remove("_all")

        for index in indices_names:
            response: dict[str, Any] = self.es_client.get_mapping(index=index)
            all_fields_dict[index] = []
            for field in response[index]["mappings"]["properties"]:
                all_fields_dict[index].append(field)
        all_es_fields = {value for values_list in all_fields_dict.values() for value in values_list}

        for element in all_expected_fields:
            if element not in all_es_fields:
                msg = f"Field {element} doesn't exist in any index."
                raise FieldNotFoundError(msg)

    def _prepare_search_query(self: Self) -> None:
        """Prepare the search query from the CLI options."""
        self.search_args = {
            "index": ",".join(self.opts.index_prefixes),
            "scroll": self.scroll_time,
            "size": self.opts.scroll_size,
            "terminate_after": self.opts.max_results,
            "body": self.opts.query,
        }
        if self.opts.sort:
            self.search_args["sort"] = self.opts.sort

        if "_all" not in self.opts.fields:
            self.search_args["_source_includes"] = ",".join(self.opts.fields)

        if self.opts.debug:
            logger.debug(using_indexes.format(indexes=", ".join(self.opts.index_prefixes)))
            query = json.dumps(self.opts.query)
            logger.debug(using_query.format(query=query))
            logger.debug(output_fields.format(fields=", ".join(self.opts.fields)))
            logger.debug(sorting_by.format(sort=self.opts.sort))

    @retry(
        wait=wait_exponential(2),
        stop=stop_after_attempt(TIMES_TO_TRY),
        reraise=True,
        retry=retry_if_exception_type(ConnectionError),
    )
    def next_scroll(self: Self, scroll_id: str) -> Any:
        """Paginate to the next page."""
        return self.es_client.scroll(scroll=self.scroll_time, scroll_id=scroll_id)

    def _write_to_temp_file(self: Self, res: Any) -> None:
        """Write the search results to a temporary file."""
        hit_list = []
        total_size = int(min(self.opts.max_results, self.num_results))
        bar = tqdm(
            desc=f"{self.opts.output_file}.tmp",
            total=total_size,
            unit="docs",
            colour="green",
        )
        try:
            while self.rows_written != total_size:
                if res["_scroll_id"] not in self.scroll_ids:
                    self.scroll_ids.append(res["_scroll_id"])

                for hit in res["hits"]["hits"]:
                    self.rows_written += 1
                    bar.update(1)
                    hit_list.append(hit)
                    if len(hit_list) == FLUSH_BUFFER:
                        self._flush_to_file(hit_list)
                        hit_list = []
                res = self.next_scroll(res["_scroll_id"])
        except ScrollExpiredError:
            logger.error("Scroll expired (multiple reads?). Saving loaded data.")
        finally:
            bar.close()
            self._flush_to_file(hit_list)

    @retry(
        wait=wait_exponential(2),
        stop=stop_after_attempt(TIMES_TO_TRY),
        reraise=True,
        retry=retry_if_exception_type(ConnectionError),
    )
    def search_query(self: Self) -> Any:
        """Search the index."""
        self._validate_fields()
        self._prepare_search_query()
        res = self.es_client.search(**self.search_args)
        self.num_results = res["hits"]["total"]["value"]

        logger.info(f"Found {self.num_results} results.")

        if self.num_results > 0:
            self._write_to_temp_file(res)

    def _flush_to_file(self: Self, hit_list: list[dict[str, Any]]) -> None:
        """Flush the search results to a temporary file."""

        def add_meta_fields() -> None:
            if self.opts.meta_fields:
                for field in self.opts.meta_fields:
                    try:
                        data[field] = hit[field]
                    except KeyError as e:  # noqa: PERF203
                        raise MetaFieldNotFoundError(meta_field_not_found.format(field=field)) from e

        with Path(f"{self.opts.output_file}.tmp").open(mode="a", encoding="utf-8") as tmp_file:
            for hit in hit_list:
                data = hit["_source"]
                data.pop("_meta", None)
                add_meta_fields()
                tmp_file.write(json.dumps(data))
                tmp_file.write("\n")

    def _clean_scroll_ids(self: Self) -> None:
        """Clear all scroll ids."""
        with contextlib.suppress(Exception):
            self.es_client.clear_scroll(scroll_id="_all")

    def _extract_headers(self: Self) -> list[str]:
        """Extract CSV headers from the first line of the file."""
        with Path(f"{self.opts.output_file}.tmp").open() as f:
            first_line = json.loads(f.readline().strip("\n"))
            return list(first_line.keys())

    def _export(self: Self) -> None:
        """Export the data."""
        headers = self._extract_headers()
        kwargs = {
            "delimiter": self.opts.delimiter,
            "output_format": self.opts.format,
        }
        Writer.write(
            headers=headers,
            total_records=self.rows_written,
            out_file=self.opts.output_file,
            **kwargs,
        )

    def export(self: Self) -> None:
        """Export the data."""
        self._check_indexes()
        self.search_query()
        self._clean_scroll_ids()
        self._export()
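
For orientation, here is a minimal usage sketch of the class above. The CliOptions and ElasticsearchClient constructors shown are assumptions (only their import paths appear in this module; the real signatures live in src/click_opt/cli_options.py and src/elastic.py), so this is illustrative rather than the project's documented API:

from src.click_opt.cli_options import CliOptions
from src.elastic import ElasticsearchClient
from src.esxport import EsXport

opts = CliOptions(...)                 # hypothetical: real arguments come from the Click CLI layer
es_client = ElasticsearchClient(opts)  # hypothetical: actual constructor defined in src/elastic.py
EsXport(opts, es_client).export()      # checks indexes, scrolls the search, writes <output_file>.tmp, then the final file

The export() call runs the four steps in order, as defined above: _check_indexes(), search_query() (which scrolls and buffers hits to the .tmp file), _clean_scroll_ids(), and _export() via Writer.write().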