Allow traversal over the trie structure

Previously, there was no way to traverse the trie structure and no way
to identify specific nodes in the trie. That made it impossible to
analyze the trie structure resulting from loading a set of flags. This
change adds type and selector properties to nodes as well as access to
the child nodes of a node to allow for the structure to be analyzed.

Bug: 202154151
Test: m out/soong/hiddenapi/hiddenapi-flags.csv
      atest --host signature_trie_test verify_overlaps_test
      pyformat -s 4 --force_quote_type double -i scripts/hiddenapi/signature_trie*
      /usr/bin/pylint --rcfile $ANDROID_BUILD_TOP/tools/repohooks/tools/pylintrc scripts/hiddenapi/signature_trie*
Change-Id: Ia4714dbf59f6fd143aa3bf3ad1a59cd073d2175b
This commit is contained in:
Paul Duffin
2022-03-09 14:28:34 +00:00
parent ea93542e90
commit 92532e72a1
2 changed files with 104 additions and 9 deletions

View File

@@ -22,6 +22,19 @@ from itertools import chain
@dataclasses.dataclass()
class Node:
"""A node in the signature trie."""
# The type of the node.
#
# Leaf nodes are of type "member".
# Interior nodes can be either "package", or "class".
type: str
# The selector of the node.
#
# That is a string that can be used to select the node, e.g. in a pattern
# that is passed to InteriorNode.get_matching_rows().
selector: str
def values(self, selector):
"""Get the values from a set of selected nodes.
@@ -48,6 +61,10 @@ class Node:
"""
raise NotImplementedError("Please Implement this method")
def child_nodes(self):
"""Get an iterable of the child nodes of this node."""
raise NotImplementedError("Please Implement this method")
# pylint: disable=line-too-long
@dataclasses.dataclass()
@@ -173,22 +190,68 @@ class InteriorNode(Node):
element_type, _ = InteriorNode.split_element(element)
return element_type
def add(self, signature, value):
@staticmethod
def elements_to_selector(elements):
"""Compute a selector for a set of elements.
A selector uniquely identifies a specific Node in the trie. It is
essentially a prefix of a signature (without the leading L).
e.g. a trie containing "Ljava/lang/Object;->String()Ljava/lang/String;"
would contain nodes with the following selectors:
* "java"
* "java/lang"
* "java/lang/Object"
* "java/lang/Object;->String()Ljava/lang/String;"
"""
signature = ""
preceding_type = ""
for element in elements:
element_type, element_value = InteriorNode.split_element(element)
separator = ""
if element_type == "package":
separator = "/"
elif element_type == "class":
if preceding_type == "class":
separator = "$"
else:
separator = "/"
elif element_type == "wildcard":
separator = "/"
elif element_type == "member":
separator += ";->"
if signature:
signature += separator
signature += element_value
preceding_type = element_type
return signature
def add(self, signature, value, only_if_matches=False):
"""Associate the value with the specific signature.
:param signature: the member signature
:param value: the value to associated with the signature
:param only_if_matches: True if the value is added only if the signature
matches at least one of the existing top level packages.
:return: n/a
"""
# Split the signature into elements.
elements = self.signature_to_elements(signature)
# Find the Node associated with the deepest class.
node = self
for element in elements[:-1]:
for index, element in enumerate(elements[:-1]):
if element in node.nodes:
node = node.nodes[element]
elif only_if_matches and index == 0:
return
else:
next_node = InteriorNode()
selector = self.elements_to_selector(elements[0:index + 1])
next_node = InteriorNode(
type=InteriorNode.element_type(element), selector=selector)
node.nodes[element] = next_node
node = next_node
# Add a Leaf containing the value and associate it with the member
@@ -201,7 +264,12 @@ class InteriorNode(Node):
"specific member")
if last_element in node.nodes:
raise Exception(f"Duplicate signature: {signature}")
node.nodes[last_element] = Leaf(value)
leaf = Leaf(
type=last_element_type,
selector=signature,
value=value,
)
node.nodes[last_element] = leaf
def get_matching_rows(self, pattern):
"""Get the values (plural) associated with the pattern.
@@ -212,10 +280,6 @@ class InteriorNode(Node):
If the pattern is a class then this will return a list containing the
values associated with all members of that class.
If the pattern is a package then this will return a list containing the
values associated with all the members of all the classes in that
package and sub-packages.
If the pattern ends with "*" then the preceding part is treated as a
package and this will return a list containing the values associated
with all the members of all the classes in that package.
@@ -261,6 +325,9 @@ class InteriorNode(Node):
if selector(key):
node.append_values(values, lambda x: True)
def child_nodes(self):
return self.nodes.values()
@dataclasses.dataclass()
class Leaf(Node):
@@ -275,6 +342,9 @@ class Leaf(Node):
def append_values(self, values, selector):
values.append([self.value])
def child_nodes(self):
return []
def signature_trie():
return InteriorNode()
return InteriorNode(type="root", selector="")

View File

@@ -27,6 +27,10 @@ class TestSignatureToElements(unittest.TestCase):
def signature_to_elements(signature):
return InteriorNode.signature_to_elements(signature)
@staticmethod
def elements_to_signature(elements):
return InteriorNode.elements_to_selector(elements)
def test_nested_inner_classes(self):
elements = [
("package", "java"),
@@ -38,6 +42,7 @@ class TestSignatureToElements(unittest.TestCase):
]
signature = "Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V"
self.assertEqual(elements, self.signature_to_elements(signature))
self.assertEqual(signature, "L" + self.elements_to_signature(elements))
def test_basic_member(self):
elements = [
@@ -48,6 +53,7 @@ class TestSignatureToElements(unittest.TestCase):
]
signature = "Ljava/lang/Object;->hashCode()I"
self.assertEqual(elements, self.signature_to_elements(signature))
self.assertEqual(signature, "L" + self.elements_to_signature(elements))
def test_double_dollar_class(self):
elements = [
@@ -61,6 +67,7 @@ class TestSignatureToElements(unittest.TestCase):
signature = "Ljava/lang/CharSequence$$ExternalSyntheticLambda0;" \
"-><init>(Ljava/lang/CharSequence;)V"
self.assertEqual(elements, self.signature_to_elements(signature))
self.assertEqual(signature, "L" + self.elements_to_signature(elements))
def test_no_member(self):
elements = [
@@ -72,6 +79,7 @@ class TestSignatureToElements(unittest.TestCase):
]
signature = "Ljava/lang/CharSequence$$ExternalSyntheticLambda0"
self.assertEqual(elements, self.signature_to_elements(signature))
self.assertEqual(signature, "L" + self.elements_to_signature(elements))
def test_wildcard(self):
elements = [
@@ -81,6 +89,7 @@ class TestSignatureToElements(unittest.TestCase):
]
signature = "java/lang/*"
self.assertEqual(elements, self.signature_to_elements(signature))
self.assertEqual(signature, self.elements_to_signature(elements))
def test_recursive_wildcard(self):
elements = [
@@ -90,6 +99,7 @@ class TestSignatureToElements(unittest.TestCase):
]
signature = "java/lang/**"
self.assertEqual(elements, self.signature_to_elements(signature))
self.assertEqual(signature, self.elements_to_signature(elements))
def test_no_packages_wildcard(self):
elements = [
@@ -97,6 +107,7 @@ class TestSignatureToElements(unittest.TestCase):
]
signature = "*"
self.assertEqual(elements, self.signature_to_elements(signature))
self.assertEqual(signature, self.elements_to_signature(elements))
def test_no_packages_recursive_wildcard(self):
elements = [
@@ -104,6 +115,7 @@ class TestSignatureToElements(unittest.TestCase):
]
signature = "**"
self.assertEqual(elements, self.signature_to_elements(signature))
self.assertEqual(signature, self.elements_to_signature(elements))
def test_invalid_no_class_or_wildcard(self):
signature = "java/lang"
@@ -121,6 +133,7 @@ class TestSignatureToElements(unittest.TestCase):
]
signature = "Ljavax/crypto/extObjectInputStream"
self.assertEqual(elements, self.signature_to_elements(signature))
self.assertEqual(signature, "L" + self.elements_to_signature(elements))
def test_invalid_pattern_wildcard(self):
pattern = "Ljava/lang/Class*"
@@ -200,6 +213,18 @@ Ljava/util/zip/ZipFile;-><clinit>()V
"Ljava/util/zip/ZipFile;-><clinit>()V",
])
def test_node_wildcard(self):
trie = self.read_trie()
node = list(trie.child_nodes())[0]
self.check_node_patterns(node, "**", [
"Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;",
"Ljava/lang/Character;->serialVersionUID:J",
"Ljava/lang/Object;->hashCode()I",
"Ljava/lang/Object;->toString()Ljava/lang/String;",
"Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V",
"Ljava/util/zip/ZipFile;-><clinit>()V",
])
# pylint: enable=line-too-long