diff --git a/cc/gen_stub_libs.py b/cc/gen_stub_libs.py index 4b7c244e5..9b40415d5 100755 --- a/cc/gen_stub_libs.py +++ b/cc/gen_stub_libs.py @@ -16,6 +16,7 @@ # """Generates source for stub shared libraries for the NDK.""" import argparse +import logging import os import re @@ -30,45 +31,28 @@ ALL_ARCHITECTURES = ( ) -class Scope(object): - """Enum for version script scope. - - Top: Top level of the file. - Global: In a version and visibility section where symbols should be visible - to the NDK. - Local: In a visibility section of a public version where symbols should be - hidden to the NDK. - Private: In a version where symbols should not be visible to the NDK. - """ - Top = 1 - Global = 2 - Local = 3 - Private = 4 - - -class Stack(object): - """Basic stack implementation.""" - def __init__(self): - self.stack = [] - - def push(self, obj): - """Push an item on to the stack.""" - self.stack.append(obj) - - def pop(self): - """Remove and return the item on the top of the stack.""" - return self.stack.pop() - - @property - def top(self): - """Return the top of the stack.""" - return self.stack[-1] +def logger(): + """Return the main logger for this module.""" + return logging.getLogger(__name__) def get_tags(line): """Returns a list of all tags on this line.""" _, _, all_tags = line.strip().partition('#') - return re.split(r'\s+', all_tags) + return [e for e in re.split(r'\s+', all_tags) if e.strip()] + + +def get_tag_value(tag): + """Returns the value of a key/value tag. + + Raises: + ValueError: Tag is not a key/value type tag. + + Returns: Value part of tag as a string. + """ + if '=' not in tag: + raise ValueError('Not a key/value tag: ' + tag) + return tag.partition('=')[2] def version_is_private(version): @@ -87,82 +71,11 @@ def should_omit_version(name, tags, arch, api): return True if not symbol_in_arch(tags, arch): return True - if not symbol_in_version(tags, arch, api): + if not symbol_in_api(tags, arch, api): return True return False -def enter_version(scope, line, version_file, arch, api): - """Enters a new version block scope.""" - if scope.top != Scope.Top: - raise RuntimeError('Encountered nested version block.') - - # Entering a new version block. By convention symbols with versions ending - # with "_PRIVATE" or "_PLATFORM" are not included in the NDK. - version_name = line.split('{')[0].strip() - tags = get_tags(line) - if should_omit_version(version_name, tags, arch, api): - scope.push(Scope.Private) - else: - scope.push(Scope.Global) # By default symbols are visible. - version_file.write(line) - - -def leave_version(scope, line, version_file): - """Leave a version block scope.""" - # There is no close to a visibility section, just the end of the version or - # a new visiblity section. - assert scope.top in (Scope.Global, Scope.Local, Scope.Private) - if scope.top != Scope.Private: - version_file.write(line) - scope.pop() - assert scope.top == Scope.Top - - -def enter_visibility(scope, line, version_file): - """Enters a new visibility block scope.""" - leave_visibility(scope) - version_file.write(line) - visibility = line.split(':')[0].strip() - if visibility == 'local': - scope.push(Scope.Local) - elif visibility == 'global': - scope.push(Scope.Global) - else: - raise RuntimeError('Unknown visiblity label: ' + visibility) - - -def leave_visibility(scope): - """Leaves a visibility block scope.""" - assert scope.top in (Scope.Global, Scope.Local) - scope.pop() - assert scope.top == Scope.Top - - -def handle_top_scope(scope, line, version_file, arch, api): - """Processes a line in the top level scope.""" - if '{' in line: - enter_version(scope, line, version_file, arch, api) - else: - raise RuntimeError('Unexpected contents at top level: ' + line) - - -def handle_private_scope(scope, line, version_file): - """Eats all input.""" - if '}' in line: - leave_version(scope, line, version_file) - - -def handle_local_scope(scope, line, version_file): - """Passes through input.""" - if ':' in line: - enter_visibility(scope, line, version_file) - elif '}' in line: - leave_version(scope, line, version_file) - else: - version_file.write(line) - - def symbol_in_arch(tags, arch): """Returns true if the symbol is present for the given architecture.""" has_arch_tags = False @@ -178,8 +91,8 @@ def symbol_in_arch(tags, arch): return not has_arch_tags -def symbol_in_version(tags, arch, version): - """Returns true if the symbol is present for the given version.""" +def symbol_in_api(tags, arch, api): + """Returns true if the symbol is present for the given API level.""" introduced_tag = None arch_specific = False for tag in tags: @@ -191,7 +104,7 @@ def symbol_in_version(tags, arch, version): arch_specific = True elif tag == 'future': # This symbol is not in any released API level. - # TODO(danalbert): These need to be emitted for version == current. + # TODO(danalbert): These need to be emitted for api == current. # That's not a construct we have yet, so just skip it for now. return False @@ -200,66 +113,204 @@ def symbol_in_version(tags, arch, version): # available. return True - # The tag is a key=value pair, and we only care about the value now. - _, _, version_str = introduced_tag.partition('=') - return version >= int(version_str) + return api >= int(get_tag_value(introduced_tag)) -def handle_global_scope(scope, line, src_file, version_file, arch, api): - """Emits present symbols to the version file and stub source file.""" - if ':' in line: - enter_visibility(scope, line, version_file) - return - if '}' in line: - leave_version(scope, line, version_file) - return +def symbol_versioned_in_api(tags, api): + """Returns true if the symbol should be versioned for the given API. - if ';' not in line: - raise RuntimeError('Expected ; to terminate symbol: ' + line) - if '*' in line: - raise RuntimeError('Wildcard global symbols are not permitted.') + This models the `versioned=API` tag. This should be a very uncommonly + needed tag, and is really only needed to fix versioning mistakes that are + already out in the wild. - # Line is now in the format "; # tags" - # Tags are whitespace separated. - symbol_name, _, rest = line.strip().partition(';') - tags = get_tags(line) - - if not symbol_in_arch(tags, arch): - return - if not symbol_in_version(tags, arch, api): - return - - if 'var' in tags: - src_file.write('int {} = 0;\n'.format(symbol_name)) - else: - src_file.write('void {}() {{}}\n'.format(symbol_name)) - version_file.write(line) + For example, some of libc's __aeabi_* functions were originally placed in + the private version, but that was incorrect. They are now in LIBC_N, but + when building against any version prior to N we need the symbol to be + unversioned (otherwise it won't resolve on M where it is private). + """ + for tag in tags: + if tag.startswith('versioned='): + return api >= int(get_tag_value(tag)) + # If there is no "versioned" tag, the tag has been versioned for as long as + # it was introduced. + return True -def generate(symbol_file, src_file, version_file, arch, api): - """Generates the stub source file and version script.""" - scope = Stack() - scope.push(Scope.Top) - for line in symbol_file: - if line.strip() == '' or line.strip().startswith('#'): - version_file.write(line) - elif scope.top == Scope.Top: - handle_top_scope(scope, line, version_file, arch, api) - elif scope.top == Scope.Private: - handle_private_scope(scope, line, version_file) - elif scope.top == Scope.Local: - handle_local_scope(scope, line, version_file) - elif scope.top == Scope.Global: - handle_global_scope(scope, line, src_file, version_file, arch, api) +class ParseError(RuntimeError): + """An error that occurred while parsing a symbol file.""" + pass + + +class Version(object): + """A version block of a symbol file.""" + def __init__(self, name, base, tags, symbols): + self.name = name + self.base = base + self.tags = tags + self.symbols = symbols + + def __eq__(self, other): + if self.name != other.name: + return False + if self.base != other.base: + return False + if self.tags != other.tags: + return False + if self.symbols != other.symbols: + return False + return True + + +class Symbol(object): + """A symbol definition from a symbol file.""" + def __init__(self, name, tags): + self.name = name + self.tags = tags + + def __eq__(self, other): + return self.name == other.name and set(self.tags) == set(other.tags) + + +class SymbolFileParser(object): + """Parses NDK symbol files.""" + def __init__(self, input_file): + self.input_file = input_file + self.current_line = None + + def parse(self): + """Parses the symbol file and returns a list of Version objects.""" + versions = [] + while self.next_line() != '': + if '{' in self.current_line: + versions.append(self.parse_version()) + else: + raise ParseError( + 'Unexpected contents at top level: ' + self.current_line) + return versions + + def parse_version(self): + """Parses a single version section and returns a Version object.""" + name = self.current_line.split('{')[0].strip() + tags = get_tags(self.current_line) + symbols = [] + global_scope = True + while self.next_line() != '': + if '}' in self.current_line: + # Line is something like '} BASE; # tags'. Both base and tags + # are optional here. + base = self.current_line.partition('}')[2] + base = base.partition('#')[0].strip() + if not base.endswith(';'): + raise ParseError( + 'Unterminated version block (expected ;).') + base = base.rstrip(';').rstrip() + if base == '': + base = None + return Version(name, base, tags, symbols) + elif ':' in self.current_line: + visibility = self.current_line.split(':')[0].strip() + if visibility == 'local': + global_scope = False + elif visibility == 'global': + global_scope = True + else: + raise ParseError('Unknown visiblity label: ' + visibility) + elif global_scope: + symbols.append(self.parse_symbol()) + else: + # We're in a hidden scope. Ignore everything. + pass + raise ParseError('Unexpected EOF in version block.') + + def parse_symbol(self): + """Parses a single symbol line and returns a Symbol object.""" + if ';' not in self.current_line: + raise ParseError( + 'Expected ; to terminate symbol: ' + self.current_line) + if '*' in self.current_line: + raise ParseError( + 'Wildcard global symbols are not permitted.') + # Line is now in the format "; # tags" + name, _, _ = self.current_line.strip().partition(';') + tags = get_tags(self.current_line) + return Symbol(name, tags) + + def next_line(self): + """Returns the next non-empty non-comment line. + + A return value of '' indicates EOF. + """ + line = self.input_file.readline() + while line.strip() == '' or line.strip().startswith('#'): + line = self.input_file.readline() + + # We want to skip empty lines, but '' indicates EOF. + if line == '': + break + self.current_line = line + return self.current_line + + +class Generator(object): + """Output generator that writes stub source files and version scripts.""" + def __init__(self, src_file, version_script, arch, api): + self.src_file = src_file + self.version_script = version_script + self.arch = arch + self.api = api + + def write(self, versions): + """Writes all symbol data to the output files.""" + for version in versions: + self.write_version(version) + + def write_version(self, version): + """Writes a single version block's data to the output files.""" + name = version.name + tags = version.tags + if should_omit_version(name, tags, self.arch, self.api): + return + + version_empty = True + pruned_symbols = [] + for symbol in version.symbols: + if not symbol_in_arch(symbol.tags, self.arch): + continue + if not symbol_in_api(symbol.tags, self.arch, self.api): + continue + + if symbol_versioned_in_api(symbol.tags, self.api): + version_empty = False + pruned_symbols.append(symbol) + + if len(pruned_symbols) > 0: + if not version_empty: + self.version_script.write(version.name + ' {\n') + self.version_script.write(' global:\n') + for symbol in pruned_symbols: + if symbol_versioned_in_api(symbol.tags, self.api): + self.version_script.write(' ' + symbol.name + ';\n') + + if 'var' in symbol.tags: + self.src_file.write('int {} = 0;\n'.format(symbol.name)) + else: + self.src_file.write('void {}() {{}}\n'.format(symbol.name)) + + if not version_empty: + base = '' if version.base is None else ' ' + version.base + self.version_script.write('}' + base + ';\n') def parse_args(): """Parses and returns command line arguments.""" parser = argparse.ArgumentParser() - parser.add_argument('--api', type=int, help='API level being targeted.') + parser.add_argument('-v', '--verbose', action='count', default=0) + parser.add_argument( - '--arch', choices=ALL_ARCHITECTURES, + '--api', type=int, required=True, help='API level being targeted.') + parser.add_argument( + '--arch', choices=ALL_ARCHITECTURES, required=True, help='Architecture being targeted.') parser.add_argument( @@ -278,11 +329,19 @@ def main(): """Program entry point.""" args = parse_args() + verbose_map = (logging.WARNING, logging.INFO, logging.DEBUG) + verbosity = args.verbose + if verbosity > 2: + verbosity = 2 + logging.basicConfig(level=verbose_map[verbosity]) + with open(args.symbol_file) as symbol_file: - with open(args.stub_src, 'w') as src_file: - with open(args.version_script, 'w') as version_file: - generate(symbol_file, src_file, version_file, args.arch, - args.api) + versions = SymbolFileParser(symbol_file).parse() + + with open(args.stub_src, 'w') as src_file: + with open(args.version_script, 'w') as version_file: + generator = Generator(src_file, version_file, args.arch, args.api) + generator.write(versions) if __name__ == '__main__': diff --git a/cc/test_gen_stub_libs.py b/cc/test_gen_stub_libs.py new file mode 100755 index 000000000..8436a4804 --- /dev/null +++ b/cc/test_gen_stub_libs.py @@ -0,0 +1,440 @@ +#!/usr/bin/env python +# +# Copyright (C) 2016 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tests for gen_stub_libs.py.""" +import cStringIO +import textwrap +import unittest + +import gen_stub_libs as gsl + + +# pylint: disable=missing-docstring + + +class TagsTest(unittest.TestCase): + def test_get_tags_no_tags(self): + self.assertEqual([], gsl.get_tags('')) + self.assertEqual([], gsl.get_tags('foo bar baz')) + + def test_get_tags(self): + self.assertEqual(['foo', 'bar'], gsl.get_tags('# foo bar')) + self.assertEqual(['bar', 'baz'], gsl.get_tags('foo # bar baz')) + + def test_get_tag_value(self): + self.assertEqual('bar', gsl.get_tag_value('foo=bar')) + self.assertEqual('bar=baz', gsl.get_tag_value('foo=bar=baz')) + with self.assertRaises(ValueError): + gsl.get_tag_value('foo') + + +class PrivateVersionTest(unittest.TestCase): + def test_version_is_private(self): + self.assertFalse(gsl.version_is_private('foo')) + self.assertFalse(gsl.version_is_private('PRIVATE')) + self.assertFalse(gsl.version_is_private('PLATFORM')) + self.assertFalse(gsl.version_is_private('foo_private')) + self.assertFalse(gsl.version_is_private('foo_platform')) + self.assertFalse(gsl.version_is_private('foo_PRIVATE_')) + self.assertFalse(gsl.version_is_private('foo_PLATFORM_')) + + self.assertTrue(gsl.version_is_private('foo_PRIVATE')) + self.assertTrue(gsl.version_is_private('foo_PLATFORM')) + + +class SymbolPresenceTest(unittest.TestCase): + def test_symbol_in_arch(self): + self.assertTrue(gsl.symbol_in_arch([], 'arm')) + self.assertTrue(gsl.symbol_in_arch(['arm'], 'arm')) + + self.assertFalse(gsl.symbol_in_arch(['x86'], 'arm')) + + def test_symbol_in_api(self): + self.assertTrue(gsl.symbol_in_api([], 'arm', 9)) + self.assertTrue(gsl.symbol_in_api(['introduced=9'], 'arm', 9)) + self.assertTrue(gsl.symbol_in_api(['introduced=9'], 'arm', 14)) + self.assertTrue(gsl.symbol_in_api(['introduced-arm=9'], 'arm', 14)) + self.assertTrue(gsl.symbol_in_api(['introduced-arm=9'], 'arm', 14)) + self.assertTrue(gsl.symbol_in_api(['introduced-x86=14'], 'arm', 9)) + self.assertTrue(gsl.symbol_in_api( + ['introduced-arm=9', 'introduced-x86=21'], 'arm', 14)) + self.assertTrue(gsl.symbol_in_api( + ['introduced=9', 'introduced-x86=21'], 'arm', 14)) + self.assertTrue(gsl.symbol_in_api( + ['introduced=21', 'introduced-arm=9'], 'arm', 14)) + + self.assertFalse(gsl.symbol_in_api(['introduced=14'], 'arm', 9)) + self.assertFalse(gsl.symbol_in_api(['introduced-arm=14'], 'arm', 9)) + self.assertFalse(gsl.symbol_in_api(['future'], 'arm', 9)) + self.assertFalse(gsl.symbol_in_api( + ['introduced=9', 'future'], 'arm', 14)) + self.assertFalse(gsl.symbol_in_api( + ['introduced-arm=9', 'future'], 'arm', 14)) + self.assertFalse(gsl.symbol_in_api( + ['introduced-arm=21', 'introduced-x86=9'], 'arm', 14)) + self.assertFalse(gsl.symbol_in_api( + ['introduced=9', 'introduced-arm=21'], 'arm', 14)) + self.assertFalse(gsl.symbol_in_api( + ['introduced=21', 'introduced-x86=9'], 'arm', 14)) + + # Interesting edge case: this symbol should be omitted from the + # library, but this call should still return true because none of the + # tags indiciate that it's not present in this API level. + self.assertTrue(gsl.symbol_in_api(['x86'], 'arm', 9)) + + def test_verioned_in_api(self): + self.assertTrue(gsl.symbol_versioned_in_api([], 9)) + self.assertTrue(gsl.symbol_versioned_in_api(['versioned=9'], 9)) + self.assertTrue(gsl.symbol_versioned_in_api(['versioned=9'], 14)) + + self.assertFalse(gsl.symbol_versioned_in_api(['versioned=14'], 9)) + + +class OmitVersionTest(unittest.TestCase): + def test_omit_private(self): + self.assertFalse(gsl.should_omit_version('foo', [], 'arm', 9)) + + self.assertTrue(gsl.should_omit_version('foo_PRIVATE', [], 'arm', 9)) + self.assertTrue(gsl.should_omit_version('foo_PLATFORM', [], 'arm', 9)) + + def test_omit_arch(self): + self.assertFalse(gsl.should_omit_version('foo', [], 'arm', 9)) + self.assertFalse(gsl.should_omit_version('foo', ['arm'], 'arm', 9)) + + self.assertTrue(gsl.should_omit_version('foo', ['x86'], 'arm', 9)) + + def test_omit_api(self): + self.assertFalse(gsl.should_omit_version('foo', [], 'arm', 9)) + self.assertFalse( + gsl.should_omit_version('foo', ['introduced=9'], 'arm', 9)) + + self.assertTrue( + gsl.should_omit_version('foo', ['introduced=14'], 'arm', 9)) + + +class SymbolFileParseTest(unittest.TestCase): + def test_next_line(self): + input_file = cStringIO.StringIO(textwrap.dedent("""\ + foo + + bar + # baz + qux + """)) + parser = gsl.SymbolFileParser(input_file) + self.assertIsNone(parser.current_line) + + self.assertEqual('foo', parser.next_line().strip()) + self.assertEqual('foo', parser.current_line.strip()) + + self.assertEqual('bar', parser.next_line().strip()) + self.assertEqual('bar', parser.current_line.strip()) + + self.assertEqual('qux', parser.next_line().strip()) + self.assertEqual('qux', parser.current_line.strip()) + + self.assertEqual('', parser.next_line()) + self.assertEqual('', parser.current_line) + + def test_parse_version(self): + input_file = cStringIO.StringIO(textwrap.dedent("""\ + VERSION_1 { # foo bar + baz; + qux; # woodly doodly + }; + + VERSION_2 { + } VERSION_1; # asdf + """)) + parser = gsl.SymbolFileParser(input_file) + + parser.next_line() + version = parser.parse_version() + self.assertEqual('VERSION_1', version.name) + self.assertIsNone(version.base) + self.assertEqual(['foo', 'bar'], version.tags) + + expected_symbols = [ + gsl.Symbol('baz', []), + gsl.Symbol('qux', ['woodly', 'doodly']), + ] + self.assertEqual(expected_symbols, version.symbols) + + parser.next_line() + version = parser.parse_version() + self.assertEqual('VERSION_2', version.name) + self.assertEqual('VERSION_1', version.base) + self.assertEqual([], version.tags) + + def test_parse_version_eof(self): + input_file = cStringIO.StringIO(textwrap.dedent("""\ + VERSION_1 { + """)) + parser = gsl.SymbolFileParser(input_file) + parser.next_line() + with self.assertRaises(gsl.ParseError): + parser.parse_version() + + def test_unknown_scope_label(self): + input_file = cStringIO.StringIO(textwrap.dedent("""\ + VERSION_1 { + foo: + } + """)) + parser = gsl.SymbolFileParser(input_file) + parser.next_line() + with self.assertRaises(gsl.ParseError): + parser.parse_version() + + def test_parse_symbol(self): + input_file = cStringIO.StringIO(textwrap.dedent("""\ + foo; + bar; # baz qux + """)) + parser = gsl.SymbolFileParser(input_file) + + parser.next_line() + symbol = parser.parse_symbol() + self.assertEqual('foo', symbol.name) + self.assertEqual([], symbol.tags) + + parser.next_line() + symbol = parser.parse_symbol() + self.assertEqual('bar', symbol.name) + self.assertEqual(['baz', 'qux'], symbol.tags) + + def test_wildcard_symbol_global(self): + input_file = cStringIO.StringIO(textwrap.dedent("""\ + VERSION_1 { + *; + }; + """)) + parser = gsl.SymbolFileParser(input_file) + parser.next_line() + with self.assertRaises(gsl.ParseError): + parser.parse_version() + + def test_wildcard_symbol_local(self): + input_file = cStringIO.StringIO(textwrap.dedent("""\ + VERSION_1 { + local: + *; + }; + """)) + parser = gsl.SymbolFileParser(input_file) + parser.next_line() + version = parser.parse_version() + self.assertEqual([], version.symbols) + + def test_missing_semicolon(self): + input_file = cStringIO.StringIO(textwrap.dedent("""\ + VERSION_1 { + foo + }; + """)) + parser = gsl.SymbolFileParser(input_file) + parser.next_line() + with self.assertRaises(gsl.ParseError): + parser.parse_version() + + def test_parse_fails_invalid_input(self): + with self.assertRaises(gsl.ParseError): + input_file = cStringIO.StringIO('foo') + parser = gsl.SymbolFileParser(input_file) + parser.parse() + + def test_parse(self): + input_file = cStringIO.StringIO(textwrap.dedent("""\ + VERSION_1 { + local: + hidden1; + global: + foo; + bar; # baz + }; + + VERSION_2 { # wasd + # Implicit global scope. + woodly; + doodly; # asdf + local: + qwerty; + } VERSION_1; + """)) + parser = gsl.SymbolFileParser(input_file) + versions = parser.parse() + + expected = [ + gsl.Version('VERSION_1', None, [], [ + gsl.Symbol('foo', []), + gsl.Symbol('bar', ['baz']), + ]), + gsl.Version('VERSION_2', 'VERSION_1', ['wasd'], [ + gsl.Symbol('woodly', []), + gsl.Symbol('doodly', ['asdf']), + ]), + ] + + self.assertEqual(expected, versions) + + +class GeneratorTest(unittest.TestCase): + def test_omit_version(self): + # Thorough testing of the cases involved here is handled by + # OmitVersionTest, PrivateVersionTest, and SymbolPresenceTest. + src_file = cStringIO.StringIO() + version_file = cStringIO.StringIO() + generator = gsl.Generator(src_file, version_file, 'arm', 9) + + version = gsl.Version('VERSION_PRIVATE', None, [], [ + gsl.Symbol('foo', []), + ]) + generator.write_version(version) + self.assertEqual('', src_file.getvalue()) + self.assertEqual('', version_file.getvalue()) + + version = gsl.Version('VERSION', None, ['x86'], [ + gsl.Symbol('foo', []), + ]) + generator.write_version(version) + self.assertEqual('', src_file.getvalue()) + self.assertEqual('', version_file.getvalue()) + + version = gsl.Version('VERSION', None, ['introduced=14'], [ + gsl.Symbol('foo', []), + ]) + generator.write_version(version) + self.assertEqual('', src_file.getvalue()) + self.assertEqual('', version_file.getvalue()) + + def test_omit_symbol(self): + # Thorough testing of the cases involved here is handled by + # SymbolPresenceTest. + src_file = cStringIO.StringIO() + version_file = cStringIO.StringIO() + generator = gsl.Generator(src_file, version_file, 'arm', 9) + + version = gsl.Version('VERSION_1', None, [], [ + gsl.Symbol('foo', ['x86']), + ]) + generator.write_version(version) + self.assertEqual('', src_file.getvalue()) + self.assertEqual('', version_file.getvalue()) + + version = gsl.Version('VERSION_1', None, [], [ + gsl.Symbol('foo', ['introduced=14']), + ]) + generator.write_version(version) + self.assertEqual('', src_file.getvalue()) + self.assertEqual('', version_file.getvalue()) + + def test_write(self): + src_file = cStringIO.StringIO() + version_file = cStringIO.StringIO() + generator = gsl.Generator(src_file, version_file, 'arm', 9) + + versions = [ + gsl.Version('VERSION_1', None, [], [ + gsl.Symbol('foo', []), + gsl.Symbol('bar', ['var']), + ]), + gsl.Version('VERSION_2', 'VERSION_1', [], [ + gsl.Symbol('baz', []), + ]), + gsl.Version('VERSION_3', 'VERSION_1', [], [ + gsl.Symbol('qux', ['versioned=14']), + ]), + ] + + generator.write(versions) + expected_src = textwrap.dedent("""\ + void foo() {} + int bar = 0; + void baz() {} + void qux() {} + """) + self.assertEqual(expected_src, src_file.getvalue()) + + expected_version = textwrap.dedent("""\ + VERSION_1 { + global: + foo; + bar; + }; + VERSION_2 { + global: + baz; + } VERSION_1; + """) + self.assertEqual(expected_version, version_file.getvalue()) + + +class IntegrationTest(unittest.TestCase): + def test_integration(self): + input_file = cStringIO.StringIO(textwrap.dedent("""\ + VERSION_1 { + global: + foo; # var + bar; # x86 + local: + *; + }; + + VERSION_2 { # arm + baz; # introduced=9 + qux; # versioned=14 + } VERSION_1; + + VERSION_3 { # introduced=14 + woodly; + doodly; # var + } VERSION_2; + """)) + parser = gsl.SymbolFileParser(input_file) + versions = parser.parse() + + src_file = cStringIO.StringIO() + version_file = cStringIO.StringIO() + generator = gsl.Generator(src_file, version_file, 'arm', 9) + generator.write(versions) + + expected_src = textwrap.dedent("""\ + int foo = 0; + void baz() {} + void qux() {} + """) + self.assertEqual(expected_src, src_file.getvalue()) + + expected_version = textwrap.dedent("""\ + VERSION_1 { + global: + foo; + }; + VERSION_2 { + global: + baz; + } VERSION_1; + """) + self.assertEqual(expected_version, version_file.getvalue()) + + +def main(): + suite = unittest.TestLoader().loadTestsFromName(__name__) + unittest.TextTestRunner(verbosity=3).run(suite) + + +if __name__ == '__main__': + main()