Coverage for phml\utils\transform\sanitize\clean.py: 7%
69 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-08 11:07 -0600
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-08 11:07 -0600
1# pylint: disable=missing-module-docstring
2from re import match
3from typing import Optional
5from phml.nodes import AST, Element, Root
7from .schema import Schema
def sanatize(tree: AST | Root | Element, schema: Optional[Schema] = None):
    """Sanitize elements and attributes in the phml tree.

    Should be used when handling data from an unknown source. It should be run
    on an AST that has already been compiled to html so that no unknown values
    are left unchecked.

    By default the sanitization schema uses the github schema and follows the
    hast sanitize utility.

    * [github schema](https://github.com/syntax-tree/hast-util-sanitize/blob/main/lib/schema.js)
    * [hast sanitize](https://github.com/syntax-tree/hast-util-sanitize)

    The passes run in this order:

    1. remove every element whose tag is listed in ``schema.strip``;
    2. remove elements whose tag is not in ``schema.tag_names``;
    3. remove elements listed in ``schema.ancestors`` whose parent tag is not
       one of the allowed tags;
    4. remove attributes that are not allowed by ``schema.attributes`` /
       ``schema.protocols``;
    5. add the attributes listed in ``schema.required`` when they are missing.

    Note:
        This utility will edit the tree in place.

    Args:
        tree (AST | Root | Element): The root of the tree that will be sanitized.
        schema (Optional[Schema], optional): User defined schema. When omitted
            or ``None`` a fresh ``Schema()`` is created for the call.
    """

    from phml.utils import (  # pylint: disable=import-outside-toplevel
        check,
        is_element,
        remove_nodes,
    )

    # Build the default per call instead of in the signature: a signature
    # default is evaluated once and shared, and an explicit ``None`` used to
    # crash on ``schema.strip``.
    if schema is None:
        schema = Schema()

    src = tree.tree if isinstance(tree, AST) else tree

    for strip in schema.strip:
        remove_nodes(src, ["element", {"tag": strip}])

    def has_allowed_protocol(value, protocols: list[str]) -> bool:
        """Return True when `value` starts with one of `protocols` followed by ':'."""
        # Non-string values (e.g. boolean attributes) cannot carry a protocol;
        # treat them as not matching instead of raising a TypeError.
        if not isinstance(value, str):
            return False
        # Plain prefix test: protocol names are literals, not regex patterns.
        return any(value.startswith(f"{protocol}:") for protocol in protocols)

    def allowed_values_by_name(definitions: list) -> dict:
        """Map each attribute name in the schema to its allowed values.

        A plain string definition allows any value (mapped to ``None``). A
        sequence definition is ``[name, *allowed_values]``. A later definition
        for the same name replaces an earlier one.
        """
        allowed = {}
        for definition in definitions:
            if isinstance(definition, str):
                allowed[definition] = None
            else:
                allowed[definition[0]] = list(definition[1:])
        return allowed

    def attributes_to_remove(properties: dict, definitions: list) -> list[str]:
        """Build the list of attribute names to remove from `properties`."""
        allowed = allowed_values_by_name(definitions)
        rejected = []
        for name, value in properties.items():
            # Attribute is not mentioned by the schema for this tag.
            if name not in allowed:
                rejected.append(name)
                continue

            # Attribute is restricted to specific values and this is not one.
            values = allowed[name]
            if values is not None and value not in values:
                rejected.append(name)
                continue

            # Attribute must use one of the allowed protocols.
            if name in schema.protocols and not has_allowed_protocol(
                value, schema.protocols[name]
            ):
                rejected.append(name)
        return rejected

    def remove_disallowed_tags(node: Root | Element):
        """Drop child elements whose tag is not in ``schema.tag_names``."""
        rejected = []
        for child in node.children:
            if not check(child, "element"):
                continue
            if not is_element(child, schema.tag_names):
                rejected.append(child)
            else:
                remove_disallowed_tags(child)

        # Remove after iterating so the list is not mutated mid-loop.
        for element in rejected:
            node.children.remove(element)

    def remove_misplaced_elements(node: Root | Element):
        """Drop child elements whose parent tag is not an allowed ancestor."""
        # NOTE(review): only the direct parent is inspected, as in the original
        # code. hast-util-sanitize accepts a match anywhere in the ancestor
        # chain -- confirm which behaviour is intended.
        rejected = []
        for child in node.children:
            if not check(child, "element"):
                continue
            if child.tag in schema.ancestors:
                # The parent may be the root (or missing) and have no tag; that
                # counts as "not an allowed ancestor" instead of raising.
                parent_tag = getattr(child.parent, "tag", None)
                if parent_tag not in schema.ancestors[child.tag]:
                    rejected.append(child)
                    continue
            remove_misplaced_elements(child)

        for element in rejected:
            node.children.remove(element)

    def remove_disallowed_attributes(node: Root | Element):
        """Strip attributes rejected by the schema from every descendant element."""
        for child in node.children:
            if not check(child, "element"):
                continue
            # Elements whose tag has no entry in ``schema.attributes`` keep
            # their attributes untouched, as in the original code.
            if child.tag in schema.attributes:
                for name in attributes_to_remove(
                    child.properties, schema.attributes[child.tag]
                ):
                    child.properties.pop(name, None)
            # Always descend: children of a matched element must be checked too.
            remove_disallowed_attributes(child)

    def add_required_attributes(node: Root | Element):
        """Add missing required attributes to every descendant element."""
        for child in node.children:
            if not check(child, "element"):
                continue
            if child.tag in schema.required:
                for attr, value in schema.required[child.tag].items():
                    if attr not in child.properties:
                        child[attr] = value
            # Always descend: children of a matched element must be checked too.
            add_required_attributes(child)

    remove_disallowed_tags(src)
    remove_misplaced_elements(src)
    remove_disallowed_attributes(src)
    add_required_attributes(src)