Coverage for dj/models/node.py: 100%

263 statements  


1""" 

2Model for nodes. 

3""" 

4# pylint: disable=too-many-instance-attributes 

5import enum 

6from dataclasses import dataclass 

7from datetime import datetime, timezone 

8from functools import partial 

9from typing import Dict, List, Optional 

10 

11from pydantic import BaseModel, Extra 

12from pydantic import Field as PydanticField 

13from pydantic import root_validator 

14from sqlalchemy import JSON, DateTime, String 

15from sqlalchemy.sql.schema import Column as SqlaColumn 

16from sqlalchemy.sql.schema import UniqueConstraint 

17from sqlalchemy.types import Enum 

18from sqlmodel import Field, Relationship, SQLModel 

19from typing_extensions import TypedDict 

20 

21from dj.errors import DJInvalidInputException 

22from dj.models.base import BaseSQLModel, generate_display_name 

23from dj.models.catalog import Catalog 

24from dj.models.column import Column, ColumnYAML 

25from dj.models.database import Database 

26from dj.models.engine import Dialect, Engine, EngineInfo 

27from dj.models.tag import Tag, TagNodeRelationship 

28from dj.sql.parse import is_metric 

29from dj.sql.parsing.types import ColumnType 

30from dj.typing import UTCDatetime 

31from dj.utils import Version 

32 

33DEFAULT_DRAFT_VERSION = Version(major=0, minor=1) 

34DEFAULT_PUBLISHED_VERSION = Version(major=1, minor=0) 

35 

36 
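
# Note: DEFAULT_DRAFT_VERSION is the starting version for new nodes (see
# Node.current_version and NodeRevision.version below, which default to its
# string form); DEFAULT_PUBLISHED_VERSION is presumably applied elsewhere in
# the service when a draft node is published.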

@dataclass(frozen=True)
class BuildCriteria:
    """
    Criteria used for building
    - used to determine whether to use an availability state
    """

    timestamp: Optional[UTCDatetime] = None
    dialect: Dialect = Dialect.SPARK

class NodeRelationship(BaseSQLModel, table=True):  # type: ignore
    """
    Join table for self-referential many-to-many relationships between nodes.
    """

    parent_id: Optional[int] = Field(
        default=None,
        foreign_key="node.id",
        primary_key=True,
    )

    # This will default to `latest`, which points to the current version of the node,
    # or it can be a specific version.
    parent_version: Optional[str] = Field(
        default="latest",
    )

    child_id: Optional[int] = Field(
        default=None,
        foreign_key="noderevision.id",
        primary_key=True,
    )
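
# Note the asymmetry in the join table above: the parent side points at the
# umbrella Node (foreign_key="node.id") while the child side points at a
# specific NodeRevision (foreign_key="noderevision.id"), with parent_version
# selecting either "latest" or a pinned parent version.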

class CubeRelationship(BaseSQLModel, table=True):  # type: ignore
    """
    Join table for many-to-many relationships between cube nodes and metric/dimension nodes.
    """

    __tablename__ = "cube"

    cube_id: Optional[int] = Field(
        default=None,
        foreign_key="noderevision.id",
        primary_key=True,
    )

    cube_element_id: Optional[int] = Field(
        default=None,
        foreign_key="node.id",
        primary_key=True,
    )


class NodeColumns(BaseSQLModel, table=True):  # type: ignore
    """
    Join table for node columns.
    """

    node_id: Optional[int] = Field(
        default=None,
        foreign_key="noderevision.id",
        primary_key=True,
    )
    column_id: Optional[int] = Field(
        default=None,
        foreign_key="column.id",
        primary_key=True,
    )

class NodeType(str, enum.Enum):
    """
    Node type.

    A node can currently have one of 5 types:

    1. SOURCE nodes are root nodes in the DAG, and point to tables or views in a DB.
    2. TRANSFORM nodes are SQL transformations, reading from SOURCE/TRANSFORM nodes.
    3. METRIC nodes are leaves in the DAG, and have a single aggregation query.
    4. DIMENSION nodes are special SOURCE nodes that can be auto-joined with METRICS.
    5. CUBE nodes contain a reference to a set of METRICS and a set of DIMENSIONS.
    """

    SOURCE = "source"
    TRANSFORM = "transform"
    METRIC = "metric"
    DIMENSION = "dimension"
    CUBE = "cube"

class NodeMode(str, enum.Enum):
    """
    Node mode.

    A node can be in one of the following modes:

    1. PUBLISHED - Must be valid and not cause any child nodes to be invalid
    2. DRAFT - Can be invalid, have invalid parents, and include dangling references
    """

    PUBLISHED = "published"
    DRAFT = "draft"


class NodeStatus(str, enum.Enum):
    """
    Node status.

    A node can have one of the following statuses:

    1. VALID - All references to other nodes and node columns are valid
    2. INVALID - One or more parent nodes are incompatible or do not exist
    """

    VALID = "valid"
    INVALID = "invalid"


class NodeYAML(TypedDict, total=False):
    """
    Schema of a node in the YAML file.
    """

    description: str
    display_name: str
    type: NodeType
    query: str
    columns: Dict[str, ColumnYAML]


class NodeBase(BaseSQLModel):
    """
    A base node.
    """

    name: str = Field(sa_column=SqlaColumn("name", String, unique=True))
    type: NodeType = Field(sa_column=SqlaColumn(Enum(NodeType)))
    display_name: Optional[str] = Field(
        sa_column=SqlaColumn(
            "display_name",
            String,
            default=generate_display_name("name"),
        ),
        max_length=100,
    )


class NodeRevisionBase(BaseSQLModel):
    """
    A base node revision.
    """

    name: str = Field(
        sa_column=SqlaColumn("name", String, unique=False),
        foreign_key="node.name",
    )
    display_name: Optional[str] = Field(
        sa_column=SqlaColumn(
            "display_name",
            String,
            default=generate_display_name("name"),
        ),
    )
    type: NodeType = Field(sa_column=SqlaColumn(Enum(NodeType)))
    description: str = ""
    query: Optional[str] = None
    mode: NodeMode = NodeMode.PUBLISHED


class MissingParent(BaseSQLModel, table=True):  # type: ignore
    """
    A missing parent node
    """

    id: Optional[int] = Field(default=None, primary_key=True)
    name: str = Field(sa_column=SqlaColumn("name", String))
    created_at: UTCDatetime = Field(
        sa_column=SqlaColumn(DateTime(timezone=True)),
        default_factory=partial(datetime.now, timezone.utc),
    )


class NodeMissingParents(BaseSQLModel, table=True):  # type: ignore
    """
    Join table for missing parents
    """

    missing_parent_id: Optional[int] = Field(
        default=None,
        foreign_key="missingparent.id",
        primary_key=True,
    )
    referencing_node_id: Optional[int] = Field(
        default=None,
        foreign_key="noderevision.id",
        primary_key=True,
    )


class AvailabilityStateBase(BaseSQLModel):
    """
    An availability state base
    """

    catalog: str
    schema_: Optional[str] = Field(default=None)
    table: str
    valid_through_ts: int
    max_partition: List[str] = Field(sa_column=SqlaColumn(JSON))
    min_partition: List[str] = Field(sa_column=SqlaColumn(JSON))
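
# Illustrative note (an assumption, not stated in this module): max_partition
# and min_partition are stored as JSON lists of partition key values, e.g.
# something like ["2023", "04", "01"] for a table partitioned by year/month/day,
# with valid_through_ts presumably a unix timestamp for data freshness.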

class AvailabilityState(AvailabilityStateBase, table=True):  # type: ignore
    """
    The availability of materialized data for a node
    """

    id: Optional[int] = Field(default=None, primary_key=True)
    updated_at: UTCDatetime = Field(
        sa_column=SqlaColumn(DateTime(timezone=True)),
        default_factory=partial(datetime.now, timezone.utc),
    )

    def is_available(
        self,
        criteria: Optional[BuildCriteria] = None,  # pylint: disable=unused-argument
    ) -> bool:  # pragma: no cover
        """
        Determine whether an availability state is usable given the criteria
        """
        # TODO: add criteria for deciding whether an availability state should be used
        return True

class NodeAvailabilityState(BaseSQLModel, table=True):  # type: ignore
    """
    Join table for availability state
    """

    availability_id: Optional[int] = Field(
        default=None,
        foreign_key="availabilitystate.id",
        primary_key=True,
    )
    node_id: Optional[int] = Field(
        default=None,
        foreign_key="noderevision.id",
        primary_key=True,
    )


class NodeNamespace(SQLModel, table=True):  # type: ignore
    """
    A node namespace
    """

    namespace: str = Field(nullable=False, unique=True, primary_key=True)

class Node(NodeBase, table=True):  # type: ignore
    """
    Node that acts as an umbrella for all node revisions
    """

    __table_args__ = (
        UniqueConstraint("name", "namespace", name="unique_node_namespace_name"),
    )

    id: Optional[int] = Field(default=None, primary_key=True)
    namespace: Optional[str] = "default"
    current_version: str = Field(default=str(DEFAULT_DRAFT_VERSION))
    created_at: UTCDatetime = Field(
        sa_column=SqlaColumn(DateTime(timezone=True)),
        default_factory=partial(datetime.now, timezone.utc),
    )

    revisions: List["NodeRevision"] = Relationship(back_populates="node")
    cubes: List["NodeRevision"] = Relationship(back_populates="cube_elements")
    current: "NodeRevision" = Relationship(
        sa_relationship_kwargs={
            "primaryjoin": "and_(Node.id==NodeRevision.node_id, "
            "Node.current_version == NodeRevision.version)",
            "viewonly": True,
            "uselist": False,
        },
    )

    children: List["NodeRevision"] = Relationship(
        back_populates="parents",
        link_model=NodeRelationship,
        sa_relationship_kwargs={
            "primaryjoin": "Node.id==NodeRelationship.parent_id",
            "secondaryjoin": "NodeRevision.id==NodeRelationship.child_id",
        },
    )

    tags: List["Tag"] = Relationship(
        back_populates="nodes",
        link_model=TagNodeRelationship,
        sa_relationship_kwargs={
            "primaryjoin": "TagNodeRelationship.node_id==Node.id",
            "secondaryjoin": "TagNodeRelationship.tag_id==Tag.id",
        },
    )

    def __hash__(self) -> int:
        return hash(self.id)
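
    # Minimal usage sketch (assumed session setup and version string format,
    # not part of this module):
    #
    #     from sqlmodel import Session, select
    #
    #     with Session(db_engine) as session:
    #         node = session.exec(select(Node).where(Node.name == "some.node")).one()
    #         node.current          # NodeRevision matching node.current_version
    #         node.current_version  # a version string, e.g. "v1.0"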

class MaterializationConfig(BaseSQLModel, table=True):  # type: ignore
    """
    Materialization configuration for a node and specific engines.
    """

    node_revision_id: int = Field(foreign_key="noderevision.id", primary_key=True)
    node_revision: "NodeRevision" = Relationship(
        back_populates="materialization_configs",
    )

    engine_id: int = Field(foreign_key="engine.id", primary_key=True)
    engine: Engine = Relationship()

    config: str = Field(nullable=False)


class NodeRevision(NodeRevisionBase, table=True):  # type: ignore
    """
    A node revision.
    """

    __table_args__ = (UniqueConstraint("version", "node_id"),)

    id: Optional[int] = Field(default=None, primary_key=True)
    version: Optional[str] = Field(default=str(DEFAULT_DRAFT_VERSION))
    node_id: Optional[int] = Field(foreign_key="node.id")
    node: Node = Relationship(back_populates="revisions")
    catalog_id: int = Field(default=None, foreign_key="catalog.id")
    catalog: Catalog = Relationship(
        back_populates="node_revisions",
        sa_relationship_kwargs={
            "lazy": "joined",
        },
    )
    schema_: Optional[str] = None
    table: Optional[str] = None
    cube_elements: List["Node"] = Relationship(  # Only used by cube nodes
        back_populates="cubes",
        link_model=CubeRelationship,
        sa_relationship_kwargs={
            "primaryjoin": "NodeRevision.id==CubeRelationship.cube_id",
            "secondaryjoin": "Node.id==CubeRelationship.cube_element_id",
            "lazy": "joined",
        },
    )
    status: NodeStatus = NodeStatus.INVALID
    updated_at: UTCDatetime = Field(
        sa_column=SqlaColumn(DateTime(timezone=True)),
        default_factory=partial(datetime.now, timezone.utc),
    )

    parents: List["Node"] = Relationship(
        back_populates="children",
        link_model=NodeRelationship,
        sa_relationship_kwargs={
            "primaryjoin": "NodeRevision.id==NodeRelationship.child_id",
            "secondaryjoin": "Node.id==NodeRelationship.parent_id",
        },
    )

    parent_links: List[NodeRelationship] = Relationship()

    missing_parents: List[MissingParent] = Relationship(
        link_model=NodeMissingParents,
        sa_relationship_kwargs={
            "primaryjoin": "NodeRevision.id==NodeMissingParents.referencing_node_id",
            "secondaryjoin": "MissingParent.id==NodeMissingParents.missing_parent_id",
            "cascade": "all, delete",
        },
    )

    columns: List[Column] = Relationship(
        link_model=NodeColumns,
        sa_relationship_kwargs={
            "primaryjoin": "NodeRevision.id==NodeColumns.node_id",
            "secondaryjoin": "Column.id==NodeColumns.column_id",
            "cascade": "all, delete",
        },
    )

    # The availability of materialized data needs to be stored on the NodeRevision
    # level in order to support pinned versions, where a node owner wants to pin
    # to a particular upstream node version.
    availability: Optional[AvailabilityState] = Relationship(
        link_model=NodeAvailabilityState,
        sa_relationship_kwargs={
            "primaryjoin": "NodeRevision.id==NodeAvailabilityState.node_id",
            "secondaryjoin": "AvailabilityState.id==NodeAvailabilityState.availability_id",
            "cascade": "all, delete",
            "uselist": False,
        },
    )

    # Nodes of type SOURCE will not have this property as their materialization
    # is not managed as a part of this service
    materialization_configs: List[MaterializationConfig] = Relationship(
        back_populates="node_revision",
    )

    def __hash__(self) -> int:
        return hash(self.id)

    def primary_key(self) -> List[Column]:
        """
        Returns the primary key columns of this node.
        """
        primary_key_columns = []
        for col in self.columns:  # pylint: disable=not-an-iterable
            if "primary_key" in {attr.attribute_type.name for attr in col.attributes}:
                primary_key_columns.append(col)
        return primary_key_columns


    def extra_validation(self) -> None:
        """
        Extra validation for node data.
        """
        if self.type in (NodeType.SOURCE, NodeType.CUBE):
            if self.query:
                raise DJInvalidInputException(
                    f"Node {self.name} of type {self.type} should not have a query",
                )

        if self.type in {NodeType.TRANSFORM, NodeType.METRIC, NodeType.DIMENSION}:
            if not self.query:
                raise DJInvalidInputException(
                    f"Node {self.name} of type {self.type} needs a query",
                )

        if self.type == NodeType.METRIC:
            if not is_metric(self.query):
                raise DJInvalidInputException(
                    f"Node {self.name} of type metric has an invalid query, "
                    "should have a single aggregation",
                )

        if self.type == NodeType.CUBE:
            if not self.cube_elements:
                raise DJInvalidInputException(
                    f"Node {self.name} of type cube needs cube elements",
                )
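
    # Illustrative sketch of the validation rules above (the constructor kwargs
    # here are an assumption for the example, not a documented call):
    #
    #     rev = NodeRevision(name="num_orders", type=NodeType.METRIC, query=None)
    #     rev.extra_validation()  # raises DJInvalidInputException: needs a query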

class ImmutableNodeFields(BaseSQLModel):
    """
    Node fields that cannot be changed
    """

    name: str
    namespace: str = "default"


class MutableNodeFields(BaseSQLModel):
    """
    Node fields that can be changed.
    """

    display_name: Optional[str]
    description: str
    mode: NodeMode
    primary_key: Optional[List[str]]


class MutableNodeQueryField(BaseSQLModel):
    """
    Query field for node.
    """

    query: str


class NodeNameOutput(SQLModel):
    """
    Node name only
    """

    name: str


class AttributeTypeName(BaseSQLModel):
    """
    Attribute type name.
    """

    namespace: str
    name: str


class AttributeOutput(BaseSQLModel):
    """
    Column attribute output.
    """

    attribute_type: AttributeTypeName


class ColumnOutput(SQLModel):
    """
    A simplified column schema, without ID or dimensions.
    """

    name: str
    type: ColumnType
    attributes: Optional[List[AttributeOutput]]
    dimension: Optional[NodeNameOutput]

    class Config:  # pylint: disable=too-few-public-methods
        """
        Should perform validation on assignment
        """

        validate_assignment = True

    @root_validator
    def type_string(cls, values):  # pylint: disable=no-self-argument
        """
        Extracts the type as a string
        """
        values["type"] = str(values.get("type"))
        return values


class SourceColumnOutput(SQLModel):
    """
    A column used in creation of a source node
    """

    name: str
    type: ColumnType
    attributes: Optional[List[AttributeOutput]]
    dimension: Optional[str]

    class Config:  # pylint: disable=too-few-public-methods
        """
        Should perform validation on assignment
        """

        validate_assignment = True

    @root_validator
    def type_string(cls, values):  # pylint: disable=no-self-argument
        """
        Extracts the type as a string
        """
        values["type"] = str(values.get("type"))
        return values


class SourceNodeFields(BaseSQLModel):
    """
    Source node fields that can be changed.
    """

    catalog: str
    schema_: str
    table: str
    columns: Optional[List["SourceColumnOutput"]] = []


class CubeNodeFields(BaseSQLModel):
    """
    Cube node fields that can be changed
    """

    display_name: Optional[str]
    cube_elements: List[str]
    description: str
    mode: NodeMode

#
# Create and Update objects
#


class CreateNode(ImmutableNodeFields, MutableNodeFields, MutableNodeQueryField):
    """
    Create non-source node object.
    """


class CreateSourceNode(ImmutableNodeFields, MutableNodeFields, SourceNodeFields):
    """
    A create object for source nodes
    """

class CreateCubeNode(ImmutableNodeFields, CubeNodeFields):
    """
    A create object for cube nodes
    """

    class Config:  # pylint: disable=too-few-public-methods
        """
        Do not allow extra fields in input
        """

        extra = Extra.forbid


class UpdateNode(MutableNodeFields, SourceNodeFields):
    """
    Update node object where all fields are optional
    """
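
    # The dict comprehension below rebuilds the inherited annotations (plus the
    # query field from MutableNodeQueryField, which is not a base class of
    # UpdateNode) with every type wrapped in Optional, so a partial update can
    # supply any subset of fields.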

    __annotations__ = {
        k: Optional[v]
        for k, v in {
            **SourceNodeFields.__annotations__,  # pylint: disable=E1101
            **MutableNodeFields.__annotations__,  # pylint: disable=E1101
            **MutableNodeQueryField.__annotations__,  # pylint: disable=E1101
        }.items()
    }

    class Config:  # pylint: disable=too-few-public-methods
        """
        Do not allow fields other than the ones defined here.
        """

        extra = Extra.forbid


class UpsertMaterializationConfig(BaseSQLModel):
    """
    An upsert object for materialization configs
    """

    engine_name: str
    engine_version: str
    config: str


#
# Response output objects
#

class OutputModel(BaseModel):
    """
    An output model with the ability to flatten fields. When fields are created with
    `Field(flatten=True)`, the field's values will be automatically flattened into the
    parent output model.
    """

    def _iter(self, *args, to_dict: bool = False, **kwargs):
        for dict_key, value in super()._iter(to_dict, *args, **kwargs):
            if to_dict and self.__fields__[dict_key].field_info.extra.get(
                "flatten",
                False,
            ):
                assert isinstance(value, dict)
                for key, val in value.items():
                    yield key, val
            else:
                yield dict_key, value
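
# For example (a sketch of the intended behaviour, not a doctest): NodeOutput
# below declares `current: NodeRevisionOutput = PydanticField(flatten=True)`,
# so NodeOutput.dict() is expected to inline the revision's keys (name, type,
# query, columns, ...) at the top level instead of nesting them under "current".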

class TableOutput(SQLModel):
    """
    Output for table information.
    """

    id: Optional[int]
    catalog: Optional[Catalog]
    schema_: Optional[str]
    table: Optional[str]
    database: Optional[Database]


class MaterializationConfigOutput(SQLModel):
    """
    Output for materialization config.
    """

    engine: EngineInfo
    config: str


class NodeRevisionOutput(SQLModel):
    """
    Output for a node revision with information about columns and if it is a metric.
    """

    id: int = Field(alias="node_revision_id")
    node_id: int
    type: NodeType
    name: str
    display_name: str
    version: str
    status: NodeStatus
    mode: NodeMode
    catalog: Optional[Catalog]
    schema_: Optional[str]
    table: Optional[str]
    description: str = ""
    query: Optional[str] = None
    availability: Optional[AvailabilityState] = None
    columns: List[ColumnOutput]
    updated_at: UTCDatetime
    materialization_configs: List[MaterializationConfigOutput]
    parents: List[NodeNameOutput]

    class Config:  # pylint: disable=missing-class-docstring,too-few-public-methods
        allow_population_by_field_name = True


class NodeOutput(OutputModel):
    """
    Output for a node that shows the current revision.
    """

    namespace: str
    current: NodeRevisionOutput = PydanticField(flatten=True)
    created_at: UTCDatetime
    tags: List["Tag"] = []


class NodeValidation(SQLModel):
    """
    A validation of a provided node definition
    """

    message: str
    status: NodeStatus
    node_revision: NodeRevision
    dependencies: List[NodeRevisionOutput]
    columns: List[Column]