Coverage for dj/api/nodes.py: 100%
298 statements
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-17 20:05 -0700
1"""
2Node related APIs.
3"""
4import http.client
5import logging
6import os
7from collections import defaultdict
8from http import HTTPStatus
9from typing import List, Optional, Union
11from fastapi import APIRouter, Depends
12from fastapi.responses import JSONResponse
13from sqlalchemy.orm import joinedload
14from sqlmodel import Session, select
15from starlette.requests import Request
16from starlette.responses import Response
18from dj.api.helpers import (
19 get_attribute_type,
20 get_catalog,
21 get_column,
22 get_downstream_nodes,
23 get_engine,
24 get_node_by_name,
25 get_node_namespace,
26 propagate_valid_status,
27 raise_if_node_exists,
28 resolve_downstream_references,
29 validate_node_data,
30)
31from dj.api.tags import get_tag_by_name
32from dj.errors import DJDoesNotExistException, DJException, DJInvalidInputException
33from dj.models import ColumnAttribute
34from dj.models.attribute import UniquenessScope
35from dj.models.base import generate_display_name
36from dj.models.column import Column, ColumnAttributeInput
37from dj.models.node import (
38 DEFAULT_DRAFT_VERSION,
39 DEFAULT_PUBLISHED_VERSION,
40 ColumnOutput,
41 CreateCubeNode,
42 CreateNode,
43 CreateSourceNode,
44 MaterializationConfig,
45 MissingParent,
46 Node,
47 NodeMode,
48 NodeOutput,
49 NodeRevision,
50 NodeRevisionBase,
51 NodeRevisionOutput,
52 NodeStatus,
53 NodeType,
54 NodeValidation,
55 UpdateNode,
56 UpsertMaterializationConfig,
57)
58from dj.service_clients import QueryServiceClient
59from dj.sql.parsing.backends.antlr4 import parse
60from dj.utils import (
61 Version,
62 VersionUpgrade,
63 get_namespace_from_name,
64 get_query_service_client,
65 get_session,
66)
68_logger = logging.getLogger(__name__)
69router = APIRouter()
@router.post("/nodes/validate/", response_model=NodeValidation)
def validate_a_node(
    data: Union[NodeRevisionBase, NodeRevision],
    session: Session = Depends(get_session),
) -> NodeValidation:
    """
    Validate a node revision without persisting it.

    Source nodes are rejected up front since they have no query to validate.
    """
    if data.type == NodeType.SOURCE:
        raise DJException(message="Source nodes cannot be validated")

    (
        validated_node,
        dependencies_map,
        missing_parents_map,
        type_inference_failed_columns,
    ) = validate_node_data(data, session)
    # Missing parents or uninferrable column types make the node invalid.
    status = (
        NodeStatus.INVALID
        if missing_parents_map or type_inference_failed_columns
        else NodeStatus.VALID
    )
    return NodeValidation(
        message=f"Node `{validated_node.name}` is {status}",
        status=status,
        node_revision=validated_node,
        dependencies=set(dependencies_map.keys()),
        columns=validated_node.columns,
    )
def validate_and_build_attribute(
    session: Session,
    attribute_input: ColumnAttributeInput,
    node: Node,
) -> ColumnAttribute:
    """
    Validate a single column-attribute request and build the ORM object.

    The operation is idempotent: if the column already carries the requested
    attribute type, the existing ColumnAttribute is returned unchanged.
    """
    columns_by_name = {col.name: col for col in node.current.columns}
    target_column = columns_by_name.get(attribute_input.column_name)
    if target_column is None:
        raise DJDoesNotExistException(
            message=f"Column `{attribute_input.column_name}` "
            f"does not exist on node `{node.name}`!",
        )

    # Reuse an existing attribute of the same type, if any.
    for existing in target_column.attributes:
        if existing.attribute_type.name == attribute_input.attribute_type_name:
            return existing

    # Verify attribute type exists
    attribute_type = get_attribute_type(
        session,
        attribute_input.attribute_type_name,
        attribute_input.attribute_type_namespace,
    )
    if not attribute_type:
        raise DJDoesNotExistException(
            message=f"Attribute type `{attribute_input.attribute_type_namespace}"
            f".{attribute_input.attribute_type_name}` "
            f"does not exist!",
        )

    # Verify that the attribute type is allowed for this node
    if node.type not in attribute_type.allowed_node_types:
        raise DJException(
            message=f"Attribute type `{attribute_input.attribute_type_namespace}."
            f"{attribute_type.name}` not allowed on node "
            f"type `{node.type}`!",
        )
    return ColumnAttribute(attribute_type=attribute_type, column=target_column)
def set_column_attributes_on_node(
    session: Session,
    attributes: List[ColumnAttributeInput],
    node: Node,
) -> List[Column]:
    """
    Sets the column attributes on the node if allowed.

    Builds (or reuses) a ColumnAttribute for each input, then verifies that
    attribute types carrying a uniqueness scope are not applied to more than
    one column within that scope. On success the modified columns are
    committed and refreshed; on a violation the in-memory attribute lists
    are cleared before raising so the conflict is not flushed later.
    """
    # Column name -> column object, for every column touched by this call.
    modified_columns_map = {}
    for attribute_input in attributes:
        new_attribute = validate_and_build_attribute(session, attribute_input, node)
        # pylint: disable=no-member
        modified_columns_map[new_attribute.column.name] = new_attribute.column

    # Validate column attributes by building mapping between
    # attribute scope and columns
    attributes_columns_map = defaultdict(set)
    modified_columns = modified_columns_map.values()

    for column in modified_columns:
        for attribute in column.attributes:
            # Each uniqueness scope contributes one component to the grouping
            # key: NODE groups by attribute type, COLUMN_TYPE additionally
            # groups by the column's data type.
            scopes_map = {
                UniquenessScope.NODE: attribute.attribute_type,
                UniquenessScope.COLUMN_TYPE: column.type,
            }
            attributes_columns_map[
                (  # type: ignore
                    attribute.attribute_type,
                    tuple(
                        scopes_map[item]
                        for item in attribute.attribute_type.uniqueness_scope
                    ),
                )
            ].add(column.name)

    for (attribute, _), columns in attributes_columns_map.items():
        if len(columns) > 1 and attribute.uniqueness_scope:
            # Undo the in-memory changes before raising so the conflicting
            # attributes are not persisted by a later commit.
            for col in columns:
                modified_columns_map[col].attributes = []
            raise DJException(
                message=f"The column attribute `{attribute.name}` is scoped to be "
                f"unique to the `{attribute.uniqueness_scope}` level, but there "
                "is more than one column tagged with it: "
                f"`{', '.join(sorted(list(columns)))}`",
            )

    session.add_all(modified_columns)
    session.commit()
    # Refresh each column so callers see database-generated state.
    for col in modified_columns:
        session.refresh(col)

    session.refresh(node)
    session.refresh(node.current)
    return list(modified_columns)
@router.post(
    "/nodes/{node_name}/attributes/",
    response_model=List[ColumnOutput],
    status_code=201,
)
def set_column_attributes(
    node_name: str,
    attributes: List[ColumnAttributeInput],
    *,
    session: Session = Depends(get_session),
) -> List[ColumnOutput]:
    """
    Set column attributes for the node.
    """
    target_node = get_node_by_name(session, node_name)
    return list(  # type: ignore
        set_column_attributes_on_node(session, attributes, target_node),
    )
@router.get("/nodes/", response_model=List[NodeOutput])
def list_nodes(*, session: Session = Depends(get_session)) -> List[NodeOutput]:
    """
    List the available nodes.
    """
    # Eager-load each node's current revision to avoid N+1 queries.
    statement = select(Node).options(joinedload(Node.current))
    return session.exec(statement).unique().all()
@router.get("/nodes/{name}/", response_model=NodeOutput)
def get_a_node(name: str, *, session: Session = Depends(get_session)) -> NodeOutput:
    """
    Show the active version of the specified node.
    """
    # The helper raises a 404-style error when the node does not exist.
    return get_node_by_name(session, name, with_current=True)  # type: ignore
@router.delete("/nodes/{name}/", status_code=204)
def delete_a_node(name: str, *, session: Session = Depends(get_session)):
    """
    Delete the specified node.
    """
    node = get_node_by_name(session, name, with_current=True)

    # Find all downstream nodes and mark them as invalid
    for downstream in get_downstream_nodes(session, node.name):
        downstream.current.status = NodeStatus.INVALID
        session.add(downstream)

    # If the node is a dimension, find all columns that
    # are linked to this dimension and remove the link
    if node.type == NodeType.DIMENSION:
        linked_columns = (
            session.exec(select(Column).where(Column.dimension_id == node.id))
            .unique()
            .all()
        )
        for linked_column in linked_columns:
            linked_column.dimension_id = None
            linked_column.dimension_column = None
            session.add(linked_column)
    session.delete(node)
    session.commit()
    return Response(status_code=HTTPStatus.NO_CONTENT.value)
@router.post("/nodes/{name}/materialization/", status_code=201)
def upsert_a_materialization_config(
    name: str,
    data: UpsertMaterializationConfig,
    *,
    session: Session = Depends(get_session),
) -> JSONResponse:
    """
    Update materialization config of the specified node.

    Source nodes are rejected. If the same config already exists for the
    requested engine, no change is made (204). Otherwise a new node revision
    (major version bump) is created that carries the new config alongside any
    configs for other engines.
    """
    node = get_node_by_name(session, name, with_current=True)
    if node.type == NodeType.SOURCE:
        raise DJException(
            http_status_code=HTTPStatus.BAD_REQUEST,
            message=f"Cannot set materialization config for source node `{name}`!",
        )
    current_revision = node.current

    # Check to see if a config for this engine already exists with the exact same config
    existing_config_for_engine = [
        config
        for config in node.current.materialization_configs
        if config.engine.name == data.engine_name
    ]
    if (
        existing_config_for_engine
        and existing_config_for_engine[0].config == data.config
    ):
        return JSONResponse(
            status_code=HTTPStatus.NO_CONTENT,
            content={
                "message": (
                    f"The same materialization config provided already exists for "
                    f"node `{name}` so no update was performed."
                ),
            },
        )

    # Materialization config changed, so create a new materialization config and a new node
    # revision that references it.
    engine = get_engine(session, data.engine_name, data.engine_version)
    new_node_revision = create_new_revision_from_existing(
        session,
        current_revision,
        node,
        version_upgrade=VersionUpgrade.MAJOR,
    )

    # Carry over configs for other engines; only this engine's is replaced.
    unchanged_existing_configs = [
        config
        for config in node.current.materialization_configs
        if config.engine.name != data.engine_name
    ]
    new_config = MaterializationConfig(
        node_revision=new_node_revision,
        engine=engine,
        config=data.config,
    )
    new_node_revision.materialization_configs = unchanged_existing_configs + [  # type: ignore
        new_config,
    ]
    node.current_version = new_node_revision.version  # type: ignore

    # This will add the materialization config, the new node rev, and update the node's version.
    session.add(new_node_revision)
    session.add(node)
    session.commit()

    return JSONResponse(
        status_code=200,
        content={
            "message": (
                f"Successfully updated materialization config for node `{name}`"
                f" and engine `{engine.name}`."
            ),
        },
    )
@router.get("/nodes/{name}/revisions/", response_model=List[NodeRevisionOutput])
def list_node_revisions(
    name: str, *, session: Session = Depends(get_session)
) -> List[NodeRevisionOutput]:
    """
    List all revisions for the node.
    """
    # No need to eager-load the current revision; all revisions are returned.
    target_node = get_node_by_name(session, name, with_current=False)
    return target_node.revisions  # type: ignore
def create_node_revision(
    data: CreateNode,
    node_type: NodeType,
    session: Session,
) -> NodeRevision:
    """
    Create a non-source node revision.

    Validates the node's query, records its parents (and any missing ones),
    and attaches the inferred columns and the dependencies' catalog.

    Args:
        data: the node-creation payload.
        node_type: type of node being created (transform/dimension/metric).
        session: database session.

    Raises:
        DJException: if a published node would depend on multiple catalogs.
    """
    node_revision = NodeRevision(
        name=data.name,
        namespace=data.namespace,
        display_name=data.display_name
        if data.display_name
        else generate_display_name(data.name),
        description=data.description,
        type=node_type,
        status=NodeStatus.VALID,
        query=data.query,
        mode=data.mode,
    )
    (
        validated_node,
        dependencies_map,
        missing_parents_map,
        type_inference_failed_columns,
    ) = validate_node_data(node_revision, session)
    # Missing parents or uninferrable column types make the revision invalid.
    if missing_parents_map or type_inference_failed_columns:
        node_revision.status = NodeStatus.INVALID
    else:
        node_revision.status = NodeStatus.VALID
    node_revision.missing_parents = [
        MissingParent(name=missing_parent) for missing_parent in missing_parents_map
    ]
    # Renamed loop variables so they don't shadow the `Node` class usage below.
    new_parents = [parent.name for parent in dependencies_map]
    catalog_ids = [parent.catalog_id for parent in dependencies_map]
    # Published nodes must resolve to exactly one catalog.
    if node_revision.mode == NodeMode.PUBLISHED and len(set(catalog_ids)) != 1:
        raise DJException(
            f"Cannot create nodes with multi-catalog dependencies: {set(catalog_ids)}",
        )
    catalog_id = next(iter(catalog_ids), 0)
    parent_refs = session.exec(
        select(Node).where(
            # pylint: disable=no-member
            Node.name.in_(  # type: ignore
                new_parents,
            ),
        ),
    ).all()
    node_revision.parents = parent_refs

    _logger.info(
        "Parent nodes for %s (%s): %s",
        data.name,
        node_revision.version,
        [p.name for p in node_revision.parents],
    )
    node_revision.columns = validated_node.columns or []
    node_revision.catalog_id = catalog_id
    return node_revision
def create_cube_node_revision(
    session: Session,
    data: CreateCubeNode,
) -> NodeRevision:
    """
    Create a cube node revision.

    Partitions the requested cube elements into metrics and dimensions and
    verifies that at least one of each is present and that all elements come
    from a single catalog.
    """
    metrics = []
    dimensions = []
    catalogs = []
    for element_name in data.cube_elements:
        element = get_node_by_name(session=session, name=element_name)
        catalogs.append(element.current.catalog.name)
        if element.type == NodeType.METRIC:
            metrics.append(element)
        elif element.type == NodeType.DIMENSION:
            dimensions.append(element)
        else:
            raise DJException(
                message=(
                    f"Node {element.name} of type {element.type} "
                    "cannot be added to a cube"
                ),
                http_status_code=http.client.UNPROCESSABLE_ENTITY,
            )
    if not metrics:
        raise DJException(
            message=("At least one metric is required to create a cube node"),
            http_status_code=http.client.UNPROCESSABLE_ENTITY,
        )
    if not dimensions:
        raise DJException(
            message=("At least one dimension is required to create a cube node"),
            http_status_code=http.client.UNPROCESSABLE_ENTITY,
        )
    distinct_catalogs = set(catalogs)
    if len(distinct_catalogs) > 1:
        raise DJException(
            message=(
                f"Cannot create cube using nodes from multiple catalogs: {catalogs}"
            ),
        )
    if len(distinct_catalogs) < 1:  # pragma: no cover
        raise DJException(
            message=("Cube elements must contain a common catalog"),
        )
    return NodeRevision(
        name=data.name,
        namespace=data.namespace,
        description=data.description,
        type=NodeType.CUBE,
        cube_elements=metrics + dimensions,
    )
def save_node(
    session: Session,
    node_revision: NodeRevision,
    node: Node,
    node_mode: NodeMode,
):
    """
    Links the node and node revision together and saves them.

    Also resolves any downstream references that were waiting on this node
    and propagates valid status to those nodes.
    """
    node_revision.node = node
    if node_mode == NodeMode.DRAFT:
        node_revision.version = str(DEFAULT_DRAFT_VERSION)
    else:
        node_revision.version = str(DEFAULT_PUBLISHED_VERSION)
    node.current_version = node_revision.version
    node_revision.extra_validation()

    session.add(node)
    session.commit()

    # Nodes that referenced this one before it existed may now become valid.
    newly_valid_nodes = resolve_downstream_references(
        session=session,
        node_revision=node_revision,
    )
    propagate_valid_status(
        session=session,
        valid_nodes=newly_valid_nodes,
        catalog_id=node.current.catalog_id,  # pylint: disable=no-member
    )
    session.refresh(node.current)
@router.post("/nodes/source/", response_model=NodeOutput, status_code=201)
def create_a_source(
    data: CreateSourceNode,
    session: Session = Depends(get_session),
    query_service_client: QueryServiceClient = Depends(get_query_service_client),
) -> NodeOutput:
    """
    Create a source node. If columns are not provided, the source node's schema
    will be inferred using the configured query service.

    Raises:
        DJException: if no columns are provided and no query service is
            configured for schema inference.
    """
    raise_if_node_exists(session, data.name)

    # Extract and assign namespace if one exists
    namespace = get_namespace_from_name(data.name)
    get_node_namespace(
        session=session,
        namespace=namespace,
    )  # Will return 404 if namespace doesn't exist
    data.namespace = namespace

    node = Node(
        name=data.name,
        namespace=data.namespace,
        type=NodeType.SOURCE,
        current_version=0,
    )
    catalog = get_catalog(session=session, name=data.catalog)

    # Build columns from the user-provided definitions, resolving optional
    # dimension links against existing dimension nodes.
    columns = (
        [
            Column(
                name=column_data.name,
                type=column_data.type,
                dimension=(
                    get_node_by_name(
                        session,
                        name=column_data.dimension,
                        node_type=NodeType.DIMENSION,
                        raise_if_not_exists=False,
                    )
                ),
            )
            for column_data in data.columns
        ]
        if data.columns
        else None
    )
    # When no columns are provided, attempt to find actual table columns
    # if a query service is set
    if not columns:
        if not query_service_client:
            raise DJException(
                message="No table columns were provided and no query "
                "service is configured for table columns inference!",
            )
        columns = query_service_client.get_columns_for_table(
            data.catalog,
            data.schema_,  # type: ignore
            data.table,
            # Use the catalog's first engine for inference if one exists.
            catalog.engines[0] if catalog.engines else None,
        )

    node_revision = NodeRevision(
        name=data.name,
        namespace=data.namespace,
        display_name=data.display_name
        if data.display_name
        else generate_display_name(data.name),
        description=data.description,
        type=NodeType.SOURCE,
        status=NodeStatus.VALID,
        catalog_id=catalog.id,
        schema_=data.schema_,
        table=data.table,
        columns=columns,
        parents=[],
    )

    # Point the node to the new node revision.
    save_node(session, node_revision, node, data.mode)
    return node  # type: ignore
@router.post("/nodes/transform/", response_model=NodeOutput, status_code=201)
@router.post("/nodes/dimension/", response_model=NodeOutput, status_code=201)
@router.post("/nodes/metric/", response_model=NodeOutput, status_code=201)
def create_a_node(
    data: CreateNode,
    request: Request,
    *,
    session: Session = Depends(get_session),
) -> NodeOutput:
    """
    Create a node.

    The node type is derived from the endpoint path (transform, dimension,
    or metric). Dimension nodes must define a primary key; any primary key
    columns are tagged with the `system.primary_key` attribute after creation.
    """
    # The final path segment of the request URL determines the node type.
    node_type = NodeType(os.path.basename(os.path.normpath(request.url.path)))

    if node_type == NodeType.DIMENSION and not data.primary_key:
        raise DJInvalidInputException("Dimension nodes must define a primary key!")

    raise_if_node_exists(session, data.name)

    namespace = get_namespace_from_name(data.name)
    get_node_namespace(
        session=session,
        namespace=namespace,
    )  # Will return 404 if namespace doesn't exist
    data.namespace = namespace

    node = Node(
        name=data.name,
        namespace=data.namespace,
        type=node_type,  # already a NodeType; no need to re-wrap
        current_version=0,
    )
    node_revision = create_node_revision(data, node_type, session)
    save_node(session, node_revision, node, data.mode)
    session.refresh(node)

    column_names = {col.name for col in node_revision.columns}
    if data.primary_key and any(
        key_column not in column_names for key_column in data.primary_key
    ):
        raise DJInvalidInputException(
            f"Some columns in the primary key {','.join(data.primary_key)} "
            f"were not found in the list of available columns for the node {node.name}.",
        )
    if data.primary_key:
        attributes = [
            ColumnAttributeInput(
                attribute_type_namespace="system",
                attribute_type_name="primary_key",
                column_name=key_column,
            )
            for key_column in data.primary_key
            if key_column in column_names
        ]
        set_column_attributes_on_node(session, attributes, node)
    session.refresh(node)
    session.refresh(node.current)
    return node  # type: ignore
@router.post("/nodes/cube/", response_model=NodeOutput, status_code=201)
def create_a_cube(
    data: CreateCubeNode,
    session: Session = Depends(get_session),
) -> NodeOutput:
    """
    Create a node.
    """
    raise_if_node_exists(session, data.name)
    cube_node = Node(
        name=data.name,
        namespace=data.namespace,
        type=NodeType.CUBE,
        current_version=0,
    )
    # Build the revision (validates metrics/dimensions/catalogs) then persist.
    revision = create_cube_node_revision(session=session, data=data)
    save_node(session, revision, cube_node, data.mode)
    return cube_node  # type: ignore
@router.post("/nodes/{name}/columns/{column}/", status_code=201)
def link_a_dimension(
    name: str,
    column: str,
    dimension: Optional[str] = None,
    dimension_column: Optional[str] = None,
    session: Session = Depends(get_session),
) -> JSONResponse:
    """
    Add information to a node column
    """
    # Default the dimension node name to the column name when not supplied.
    dimension = dimension or column

    node = get_node_by_name(session=session, name=name)
    dimension_node = get_node_by_name(
        session=session,
        name=dimension,
        node_type=NodeType.DIMENSION,
    )
    if node.current.catalog.name != dimension_node.current.catalog.name:
        raise DJException(
            message=(
                "Cannot add dimension to column, because catalogs do not match: "
                f"{node.current.catalog.name}, {dimension_node.current.catalog.name}"
            ),
        )

    target_column = get_column(node.current, column)
    if dimension_column:
        # Verify that the dimension column exists...
        column_from_dimension = get_column(dimension_node.current, dimension_column)
        # ...and that its type is compatible with the target column's type.
        if not column_from_dimension.type.is_compatible(target_column.type):
            raise DJInvalidInputException(
                f"The column {target_column.name} has type {target_column.type} "
                f"and is being linked to the dimension {dimension} via the dimension"
                f" column {dimension_column}, which has type {column_from_dimension.type}."
                " These column types are incompatible and the dimension cannot be linked!",
            )

    target_column.dimension = dimension_node
    target_column.dimension_id = dimension_node.id
    target_column.dimension_column = dimension_column

    session.add(node)
    session.commit()
    session.refresh(node)
    return JSONResponse(
        status_code=201,
        content={
            "message": (
                f"Dimension node {dimension} has been successfully "
                f"linked to column {column} on node {name}"
            ),
        },
    )
@router.post("/nodes/{name}/tag/", status_code=201)
def tag_a_node(
    name: str, tag_name: str, *, session: Session = Depends(get_session)
) -> JSONResponse:
    """
    Add a tag to a node
    """
    node = get_node_by_name(session=session, name=name)
    tag = get_tag_by_name(session, name=tag_name, raise_if_not_exists=True)
    node.tags.append(tag)

    session.add(node)
    session.commit()
    # Refresh both sides of the relationship so they reflect the new link.
    for obj in (node, tag):
        session.refresh(obj)

    return JSONResponse(
        status_code=201,
        content={
            "message": (
                f"Node `{name}` has been successfully tagged with tag `{tag_name}`"
            ),
        },
    )
def create_new_revision_from_existing(  # pylint: disable=too-many-locals
    session: Session,
    old_revision: NodeRevision,
    node: Node,
    data: Optional[UpdateNode] = None,
    version_upgrade: Optional[VersionUpgrade] = None,
) -> Optional[NodeRevision]:
    """
    Creates a new revision from an existing node revision.

    Returns None when nothing changed (no minor/major change detected and no
    explicit ``version_upgrade`` requested), so callers can skip persisting.
    Metadata-only (minor) changes bump the minor version; query or column
    changes — or an explicit MAJOR upgrade — bump the major version. If the
    query changed, the new revision is re-validated and re-linked to parents.
    """
    # Minor changes: description, mode, or display name differ.
    minor_changes = (
        (data and data.description and old_revision.description != data.description)
        or (data and data.mode and old_revision.mode != data.mode)
        or (
            data
            and data.display_name
            and old_revision.display_name != data.display_name
        )
    )
    # Query changes only apply to non-source nodes.
    query_changes = (
        old_revision.type != NodeType.SOURCE
        and data
        and data.query
        and old_revision.query != data.query
    )
    # Column changes only apply to source nodes.
    column_changes = (
        old_revision.type == NodeType.SOURCE
        and data is not None
        and data.columns is not None
        and ({col.identifier() for col in old_revision.columns} != data.columns)
    )
    major_changes = query_changes or column_changes

    # If nothing has changed, do not create the new node revision
    if not minor_changes and not major_changes and not version_upgrade:
        return None

    old_version = Version.parse(node.current_version)
    new_revision = NodeRevision(
        name=old_revision.name,
        node_id=node.id,
        version=str(
            old_version.next_major_version()
            if major_changes or version_upgrade == VersionUpgrade.MAJOR
            else old_version.next_minor_version(),
        ),
        display_name=(
            data.display_name
            if data and data.display_name
            else old_revision.display_name
        ),
        description=(
            data.description if data and data.description else old_revision.description
        ),
        query=(data.query if data and data.query else old_revision.query),
        type=old_revision.type,
        columns=[
            Column(
                name=column_data.name,
                type=column_data.type,
                dimension_column=column_data.dimension,
            )
            for column_data in data.columns
        ]
        if data and data.columns
        else old_revision.columns,
        catalog=old_revision.catalog,
        schema_=old_revision.schema_,
        table=old_revision.table,
        parents=[],
        mode=data.mode if data and data.mode else old_revision.mode,
        materialization_configs=old_revision.materialization_configs,
    )

    # Link the new revision to its parents if the query has changed
    if (
        new_revision.type != NodeType.SOURCE
        and new_revision.query != old_revision.query
    ):
        (
            validated_node,
            dependencies_map,
            missing_parents_map,
            type_inference_failed_columns,
        ) = validate_node_data(new_revision, session)
        new_parents = [n.name for n in dependencies_map]
        parent_refs = session.exec(
            select(Node).where(
                # pylint: disable=no-member
                Node.name.in_(  # type: ignore
                    new_parents,
                ),
            ),
        ).all()
        new_revision.parents = list(parent_refs)
        # Missing parents or failed type inference mark the revision invalid.
        if missing_parents_map or type_inference_failed_columns:
            new_revision.status = NodeStatus.INVALID
        else:
            new_revision.status = NodeStatus.VALID
        new_revision.missing_parents = [
            MissingParent(name=missing_parent) for missing_parent in missing_parents_map
        ]
        _logger.info(
            "Parent nodes for %s (v%s): %s",
            new_revision.name,
            new_revision.version,
            [p.name for p in new_revision.parents],
        )
        new_revision.columns = validated_node.columns or []
    return new_revision
@router.patch("/nodes/{name}/", response_model=NodeOutput)
def update_a_node(
    name: str,
    data: UpdateNode,
    *,
    session: Session = Depends(get_session),
) -> NodeOutput:
    """
    Update a node.
    """
    # Lock the node row for update and force a fresh load from the database.
    statement = (
        select(Node)
        .where(Node.name == name)
        .with_for_update()
        .execution_options(populate_existing=True)
    )
    node = session.exec(statement).one_or_none()
    if node is None:
        raise DJException(
            message=f"A node with name `{name}` does not exist.",
            http_status_code=404,
        )

    new_revision = create_new_revision_from_existing(session, node.current, node, data)

    # No new revision means nothing changed; return the node untouched.
    if not new_revision:
        return node  # type: ignore

    node.current_version = new_revision.version
    new_revision.extra_validation()

    session.add(new_revision)
    session.add(node)
    session.commit()
    session.refresh(node.current)
    return node  # type: ignore
@router.get("/nodes/similarity/{node1_name}/{node2_name}")
def calculate_node_similarity(
    node1_name: str, node2_name: str, *, session: Session = Depends(get_session)
) -> JSONResponse:
    """
    Compare two nodes by how similar their queries are
    """
    node1 = get_node_by_name(session=session, name=node1_name)
    node2 = get_node_by_name(session=session, name=node2_name)
    # Source nodes have no query to compare.
    if node1.type == NodeType.SOURCE or node2.type == NodeType.SOURCE:
        raise DJException(
            message="Cannot determine similarity of source nodes",
            http_status_code=HTTPStatus.CONFLICT,
        )
    ast1 = parse(node1.current.query)  # type: ignore
    ast2 = parse(node2.current.query)  # type: ignore
    return JSONResponse(
        status_code=200,
        content={"similarity": ast1.similarity_score(ast2)},
    )
@router.get("/nodes/{name}/downstream/", response_model=List[NodeOutput])
def list_downstream_nodes(
    name: str,
    *,
    node_type: Optional[NodeType] = None,  # explicit Optional per PEP 484
    session: Session = Depends(get_session),
) -> List[NodeOutput]:
    """
    List all nodes that are downstream from the given node, filterable by type.
    """
    return get_downstream_nodes(session, name, node_type)  # type: ignore