diff --git a/scripts/hiddenapi/signature_trie.py b/scripts/hiddenapi/signature_trie.py index 5871834c2..e813a9781 100644 --- a/scripts/hiddenapi/signature_trie.py +++ b/scripts/hiddenapi/signature_trie.py @@ -22,6 +22,19 @@ from itertools import chain @dataclasses.dataclass() class Node: + """A node in the signature trie.""" + + # The type of the node. + # + # Leaf nodes are of type "member". + # Interior nodes can be either "package", or "class". + type: str + + # The selector of the node. + # + # That is a string that can be used to select the node, e.g. in a pattern + # that is passed to InteriorNode.get_matching_rows(). + selector: str def values(self, selector): """Get the values from a set of selected nodes. @@ -48,6 +61,10 @@ class Node: """ raise NotImplementedError("Please Implement this method") + def child_nodes(self): + """Get an iterable of the child nodes of this node.""" + raise NotImplementedError("Please Implement this method") + # pylint: disable=line-too-long @dataclasses.dataclass() @@ -173,22 +190,68 @@ class InteriorNode(Node): element_type, _ = InteriorNode.split_element(element) return element_type - def add(self, signature, value): + @staticmethod + def elements_to_selector(elements): + """Compute a selector for a set of elements. + + A selector uniquely identifies a specific Node in the trie. It is + essentially a prefix of a signature (without the leading L). + + e.g. a trie containing "Ljava/lang/Object;->String()Ljava/lang/String;" + would contain nodes with the following selectors: + * "java" + * "java/lang" + * "java/lang/Object" + * "java/lang/Object;->String()Ljava/lang/String;" + """ + signature = "" + preceding_type = "" + for element in elements: + element_type, element_value = InteriorNode.split_element(element) + separator = "" + if element_type == "package": + separator = "/" + elif element_type == "class": + if preceding_type == "class": + separator = "$" + else: + separator = "/" + elif element_type == "wildcard": + separator = "/" + elif element_type == "member": + separator += ";->" + + if signature: + signature += separator + + signature += element_value + + preceding_type = element_type + + return signature + + def add(self, signature, value, only_if_matches=False): """Associate the value with the specific signature. :param signature: the member signature :param value: the value to associated with the signature + :param only_if_matches: True if the value is added only if the signature + matches at least one of the existing top level packages. :return: n/a """ # Split the signature into elements. elements = self.signature_to_elements(signature) # Find the Node associated with the deepest class. node = self - for element in elements[:-1]: + for index, element in enumerate(elements[:-1]): if element in node.nodes: node = node.nodes[element] + elif only_if_matches and index == 0: + return else: - next_node = InteriorNode() + selector = self.elements_to_selector(elements[0:index + 1]) + next_node = InteriorNode( + type=InteriorNode.element_type(element), selector=selector) node.nodes[element] = next_node node = next_node # Add a Leaf containing the value and associate it with the member @@ -201,7 +264,12 @@ class InteriorNode(Node): "specific member") if last_element in node.nodes: raise Exception(f"Duplicate signature: {signature}") - node.nodes[last_element] = Leaf(value) + leaf = Leaf( + type=last_element_type, + selector=signature, + value=value, + ) + node.nodes[last_element] = leaf def get_matching_rows(self, pattern): """Get the values (plural) associated with the pattern. @@ -212,10 +280,6 @@ class InteriorNode(Node): If the pattern is a class then this will return a list containing the values associated with all members of that class. - If the pattern is a package then this will return a list containing the - values associated with all the members of all the classes in that - package and sub-packages. - If the pattern ends with "*" then the preceding part is treated as a package and this will return a list containing the values associated with all the members of all the classes in that package. @@ -261,6 +325,9 @@ class InteriorNode(Node): if selector(key): node.append_values(values, lambda x: True) + def child_nodes(self): + return self.nodes.values() + @dataclasses.dataclass() class Leaf(Node): @@ -275,6 +342,9 @@ class Leaf(Node): def append_values(self, values, selector): values.append([self.value]) + def child_nodes(self): + return [] + def signature_trie(): - return InteriorNode() + return InteriorNode(type="root", selector="") diff --git a/scripts/hiddenapi/signature_trie_test.py b/scripts/hiddenapi/signature_trie_test.py index e9644efbc..129569107 100755 --- a/scripts/hiddenapi/signature_trie_test.py +++ b/scripts/hiddenapi/signature_trie_test.py @@ -27,6 +27,10 @@ class TestSignatureToElements(unittest.TestCase): def signature_to_elements(signature): return InteriorNode.signature_to_elements(signature) + @staticmethod + def elements_to_signature(elements): + return InteriorNode.elements_to_selector(elements) + def test_nested_inner_classes(self): elements = [ ("package", "java"), @@ -38,6 +42,7 @@ class TestSignatureToElements(unittest.TestCase): ] signature = "Ljava/lang/ProcessBuilder$Redirect$1;->()V" self.assertEqual(elements, self.signature_to_elements(signature)) + self.assertEqual(signature, "L" + self.elements_to_signature(elements)) def test_basic_member(self): elements = [ @@ -48,6 +53,7 @@ class TestSignatureToElements(unittest.TestCase): ] signature = "Ljava/lang/Object;->hashCode()I" self.assertEqual(elements, self.signature_to_elements(signature)) + self.assertEqual(signature, "L" + self.elements_to_signature(elements)) def test_double_dollar_class(self): elements = [ @@ -61,6 +67,7 @@ class TestSignatureToElements(unittest.TestCase): signature = "Ljava/lang/CharSequence$$ExternalSyntheticLambda0;" \ "->(Ljava/lang/CharSequence;)V" self.assertEqual(elements, self.signature_to_elements(signature)) + self.assertEqual(signature, "L" + self.elements_to_signature(elements)) def test_no_member(self): elements = [ @@ -72,6 +79,7 @@ class TestSignatureToElements(unittest.TestCase): ] signature = "Ljava/lang/CharSequence$$ExternalSyntheticLambda0" self.assertEqual(elements, self.signature_to_elements(signature)) + self.assertEqual(signature, "L" + self.elements_to_signature(elements)) def test_wildcard(self): elements = [ @@ -81,6 +89,7 @@ class TestSignatureToElements(unittest.TestCase): ] signature = "java/lang/*" self.assertEqual(elements, self.signature_to_elements(signature)) + self.assertEqual(signature, self.elements_to_signature(elements)) def test_recursive_wildcard(self): elements = [ @@ -90,6 +99,7 @@ class TestSignatureToElements(unittest.TestCase): ] signature = "java/lang/**" self.assertEqual(elements, self.signature_to_elements(signature)) + self.assertEqual(signature, self.elements_to_signature(elements)) def test_no_packages_wildcard(self): elements = [ @@ -97,6 +107,7 @@ class TestSignatureToElements(unittest.TestCase): ] signature = "*" self.assertEqual(elements, self.signature_to_elements(signature)) + self.assertEqual(signature, self.elements_to_signature(elements)) def test_no_packages_recursive_wildcard(self): elements = [ @@ -104,6 +115,7 @@ class TestSignatureToElements(unittest.TestCase): ] signature = "**" self.assertEqual(elements, self.signature_to_elements(signature)) + self.assertEqual(signature, self.elements_to_signature(elements)) def test_invalid_no_class_or_wildcard(self): signature = "java/lang" @@ -121,6 +133,7 @@ class TestSignatureToElements(unittest.TestCase): ] signature = "Ljavax/crypto/extObjectInputStream" self.assertEqual(elements, self.signature_to_elements(signature)) + self.assertEqual(signature, "L" + self.elements_to_signature(elements)) def test_invalid_pattern_wildcard(self): pattern = "Ljava/lang/Class*" @@ -200,6 +213,18 @@ Ljava/util/zip/ZipFile;->()V "Ljava/util/zip/ZipFile;->()V", ]) + def test_node_wildcard(self): + trie = self.read_trie() + node = list(trie.child_nodes())[0] + self.check_node_patterns(node, "**", [ + "Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;", + "Ljava/lang/Character;->serialVersionUID:J", + "Ljava/lang/Object;->hashCode()I", + "Ljava/lang/Object;->toString()Ljava/lang/String;", + "Ljava/lang/ProcessBuilder$Redirect$1;->()V", + "Ljava/util/zip/ZipFile;->()V", + ]) + # pylint: enable=line-too-long