diff --git a/tools/aconfig/src/codegen_cpp.rs b/tools/aconfig/src/codegen_cpp.rs index 2944e8a2aa..a802725841 100644 --- a/tools/aconfig/src/codegen_cpp.rs +++ b/tools/aconfig/src/codegen_cpp.rs @@ -16,13 +16,18 @@ use anyhow::{ensure, Result}; use serde::Serialize; +use std::path::PathBuf; use tinytemplate::TinyTemplate; use crate::codegen; -use crate::commands::OutputFile; +use crate::commands::{CodegenMode, OutputFile}; use crate::protos::{ProtoFlagPermission, ProtoFlagState, ProtoParsedFlag}; -pub fn generate_cpp_code<'a, I>(package: &str, parsed_flags_iter: I) -> Result +pub fn generate_cpp_code<'a, I>( + package: &str, + parsed_flags_iter: I, + codegen_mode: CodegenMode, +) -> Result> where I: Iterator, { @@ -37,29 +42,66 @@ where cpp_namespace, package: package.to_string(), readwrite, + for_prod: codegen_mode == CodegenMode::Production, class_elements, }; + + let files = [ + FileSpec { + name: &format!("{}.h", header), + template: include_str!("../templates/cpp_exported_header.template"), + dir: "include", + }, + FileSpec { + name: &format!("{}.cc", header), + template: include_str!("../templates/cpp_source_file.template"), + dir: "", + }, + FileSpec { + name: &format!("{}_flag_provider.h", header), + template: match codegen_mode { + CodegenMode::Production => { + include_str!("../templates/cpp_prod_flag_provider.template") + } + CodegenMode::Test => include_str!("../templates/cpp_test_flag_provider.template"), + }, + dir: "", + }, + ]; + files.iter().map(|file| generate_file(file, &context)).collect() +} + +pub fn generate_file(file: &FileSpec, context: &Context) -> Result { let mut template = TinyTemplate::new(); - template.add_template("cpp_code_gen", include_str!("../templates/cpp.template"))?; - let contents = template.render("cpp_code_gen", &context)?; - let path = ["aconfig", &(header + ".h")].iter().collect(); + template.add_template(file.name, file.template)?; + let contents = template.render(file.name, &context)?; + let path: PathBuf = [&file.dir, &file.name].iter().collect(); Ok(OutputFile { contents: contents.into(), path }) } #[derive(Serialize)] -struct Context { +pub struct FileSpec<'a> { + pub name: &'a str, + pub template: &'a str, + pub dir: &'a str, +} + +#[derive(Serialize)] +pub struct Context { pub header: String, pub cpp_namespace: String, pub package: String, pub readwrite: bool, + pub for_prod: bool, pub class_elements: Vec, } #[derive(Serialize)] -struct ClassElement { +pub struct ClassElement { pub readwrite: bool, pub default_value: String, pub flag_name: String, + pub uppercase_flag_name: String, pub device_config_namespace: String, pub device_config_flag: String, } @@ -73,6 +115,7 @@ fn create_class_element(package: &str, pf: &ProtoParsedFlag) -> ClassElement { "false".to_string() }, flag_name: pf.name().to_string(), + uppercase_flag_name: pf.name().to_string().to_ascii_uppercase(), device_config_namespace: pf.namespace().to_string(), device_config_flag: codegen::create_device_config_ident(package, pf.name()) .expect("values checked at flag parse time"), @@ -82,51 +125,325 @@ fn create_class_element(package: &str, pf: &ProtoParsedFlag) -> ClassElement { #[cfg(test)] mod tests { use super::*; + use std::collections::HashMap; - #[test] - fn test_generate_cpp_code() { - let parsed_flags = crate::test::parse_test_flags(); - let generated = - generate_cpp_code(crate::test::TEST_PACKAGE, parsed_flags.parsed_flag.iter()).unwrap(); - assert_eq!("aconfig/com_android_aconfig_test.h", format!("{}", generated.path.display())); - let expected = r#" + const EXPORTED_PROD_HEADER_EXPECTED: &str = r#" #ifndef com_android_aconfig_test_HEADER_H #define com_android_aconfig_test_HEADER_H -#include +#include +#include +#include using namespace server_configurable_flags; namespace com::android::aconfig::test { - static const bool disabled_ro() { +class flag_provider_interface { +public: + + virtual ~flag_provider_interface() = default; + + virtual bool disabled_ro() = 0; + + virtual bool disabled_rw() = 0; + + virtual bool enabled_ro() = 0; + + virtual bool enabled_rw() = 0; + + virtual void override_flag(std::string const&, bool) {} + + virtual void reset_overrides() {} +}; + +extern std::unique_ptr provider_; + +extern std::string const DISABLED_RO; +extern std::string const DISABLED_RW; +extern std::string const ENABLED_RO; +extern std::string const ENABLED_RW; + +inline bool disabled_ro() { + return false; +} + +inline bool disabled_rw() { + return provider_->disabled_rw(); +} + +inline bool enabled_ro() { + return true; +} + +inline bool enabled_rw() { + return provider_->enabled_rw(); +} + +inline void override_flag(std::string const& name, bool val) { + return provider_->override_flag(name, val); +} + +inline void reset_overrides() { + return provider_->reset_overrides(); +} + +} +#endif +"#; + + const EXPORTED_TEST_HEADER_EXPECTED: &str = r#" +#ifndef com_android_aconfig_test_HEADER_H +#define com_android_aconfig_test_HEADER_H + +#include +#include +#include +using namespace server_configurable_flags; + +namespace com::android::aconfig::test { +class flag_provider_interface { +public: + + virtual ~flag_provider_interface() = default; + + virtual bool disabled_ro() = 0; + + virtual bool disabled_rw() = 0; + + virtual bool enabled_ro() = 0; + + virtual bool enabled_rw() = 0; + + virtual void override_flag(std::string const&, bool) {} + + virtual void reset_overrides() {} +}; + +extern std::unique_ptr provider_; + +extern std::string const DISABLED_RO; +extern std::string const DISABLED_RW; +extern std::string const ENABLED_RO; +extern std::string const ENABLED_RW; + +inline bool disabled_ro() { + return provider_->disabled_ro(); +} + +inline bool disabled_rw() { + return provider_->disabled_rw(); +} + +inline bool enabled_ro() { + return provider_->enabled_ro(); +} + +inline bool enabled_rw() { + return provider_->enabled_rw(); +} + +inline void override_flag(std::string const& name, bool val) { + return provider_->override_flag(name, val); +} + +inline void reset_overrides() { + return provider_->reset_overrides(); +} + +} +#endif +"#; + + const PROD_FLAG_PROVIDER_HEADER_EXPECTED: &str = r#" +#ifndef com_android_aconfig_test_flag_provider_HEADER_H +#define com_android_aconfig_test_flag_provider_HEADER_H + +#include "com_android_aconfig_test.h" + +namespace com::android::aconfig::test { +class flag_provider : public flag_provider_interface { +public: + + virtual bool disabled_ro() override { return false; } - static const bool disabled_rw() { + virtual bool disabled_rw() override { return GetServerConfigurableFlag( "aconfig_test", "com.android.aconfig.test.disabled_rw", "false") == "true"; } - static const bool enabled_ro() { + virtual bool enabled_ro() override { return true; } - static const bool enabled_rw() { + virtual bool enabled_rw() override { return GetServerConfigurableFlag( "aconfig_test", "com.android.aconfig.test.enabled_rw", "true") == "true"; } +}; } #endif "#; + + const TEST_FLAG_PROVIDER_HEADER_EXPECTED: &str = r#" +#ifndef com_android_aconfig_test_flag_provider_HEADER_H +#define com_android_aconfig_test_flag_provider_HEADER_H + +#include "com_android_aconfig_test.h" + +#include +#include +#include + +namespace com::android::aconfig::test { +class flag_provider : public flag_provider_interface { +private: + std::unordered_map overrides_; + std::unordered_set flag_names_; + +public: + + flag_provider() + : overrides_(), + flag_names_() { + flag_names_.insert(DISABLED_RO); + flag_names_.insert(DISABLED_RW); + flag_names_.insert(ENABLED_RO); + flag_names_.insert(ENABLED_RW); + } + + virtual bool disabled_ro() override { + auto it = overrides_.find(DISABLED_RO); + if (it != overrides_.end()) { + return it->second; + } else { + return false; + } + } + + virtual bool disabled_rw() override { + auto it = overrides_.find(DISABLED_RW); + if (it != overrides_.end()) { + return it->second; + } else { + return GetServerConfigurableFlag( + "aconfig_test", + "com.android.aconfig.test.disabled_rw", + "false") == "true"; + } + } + + virtual bool enabled_ro() override { + auto it = overrides_.find(ENABLED_RO); + if (it != overrides_.end()) { + return it->second; + } else { + return true; + } + } + + virtual bool enabled_rw() override { + auto it = overrides_.find(ENABLED_RW); + if (it != overrides_.end()) { + return it->second; + } else { + return GetServerConfigurableFlag( + "aconfig_test", + "com.android.aconfig.test.enabled_rw", + "true") == "true"; + } + } + + virtual void override_flag(std::string const& flag, bool val) override { + assert(flag_names_.count(flag)); + overrides_[flag] = val; + } + + virtual void reset_overrides() override { + overrides_.clear(); + } +}; +} +#endif +"#; + + const SOURCE_FILE_EXPECTED: &str = r#" +#include "com_android_aconfig_test.h" +#include "com_android_aconfig_test_flag_provider.h" + +namespace com::android::aconfig::test { + + std::string const DISABLED_RO = "com.android.aconfig.test.disabled_ro"; + std::string const DISABLED_RW = "com.android.aconfig.test.disabled_rw"; + std::string const ENABLED_RO = "com.android.aconfig.test.enabled_ro"; + std::string const ENABLED_RW = "com.android.aconfig.test.enabled_rw"; + + std::unique_ptr provider_ = + std::make_unique(); +} +"#; + + fn test_generate_cpp_code(mode: CodegenMode) { + let parsed_flags = crate::test::parse_test_flags(); + let generated = + generate_cpp_code(crate::test::TEST_PACKAGE, parsed_flags.parsed_flag.iter(), mode) + .unwrap(); + let mut generated_files_map = HashMap::new(); + for file in generated { + generated_files_map.insert( + String::from(file.path.to_str().unwrap()), + String::from_utf8(file.contents.clone()).unwrap(), + ); + } + + let mut target_file_path = String::from("include/com_android_aconfig_test.h"); + assert!(generated_files_map.contains_key(&target_file_path)); assert_eq!( None, crate::test::first_significant_code_diff( - expected, - &String::from_utf8(generated.contents).unwrap() + match mode { + CodegenMode::Production => EXPORTED_PROD_HEADER_EXPECTED, + CodegenMode::Test => EXPORTED_TEST_HEADER_EXPECTED, + }, + generated_files_map.get(&target_file_path).unwrap() + ) + ); + + target_file_path = String::from("com_android_aconfig_test_flag_provider.h"); + assert!(generated_files_map.contains_key(&target_file_path)); + assert_eq!( + None, + crate::test::first_significant_code_diff( + match mode { + CodegenMode::Production => PROD_FLAG_PROVIDER_HEADER_EXPECTED, + CodegenMode::Test => TEST_FLAG_PROVIDER_HEADER_EXPECTED, + }, + generated_files_map.get(&target_file_path).unwrap() + ) + ); + + target_file_path = String::from("com_android_aconfig_test.cc"); + assert!(generated_files_map.contains_key(&target_file_path)); + assert_eq!( + None, + crate::test::first_significant_code_diff( + SOURCE_FILE_EXPECTED, + generated_files_map.get(&target_file_path).unwrap() ) ); } + + #[test] + fn test_generate_cpp_code_for_prod() { + test_generate_cpp_code(CodegenMode::Production); + } + + #[test] + fn test_generate_cpp_code_for_test() { + test_generate_cpp_code(CodegenMode::Test); + } } diff --git a/tools/aconfig/src/commands.rs b/tools/aconfig/src/commands.rs index dd2087b3c3..687f3195cf 100644 --- a/tools/aconfig/src/commands.rs +++ b/tools/aconfig/src/commands.rs @@ -143,12 +143,12 @@ pub fn create_java_lib(mut input: Input, codegen_mode: CodegenMode) -> Result Result { +pub fn create_cpp_lib(mut input: Input, codegen_mode: CodegenMode) -> Result> { let parsed_flags = input.try_parse_flags()?; let Some(package) = find_unique_package(&parsed_flags) else { bail!("no parsed flags, or the parsed flags use different packages"); }; - generate_cpp_code(package, parsed_flags.parsed_flag.iter()) + generate_cpp_code(package, parsed_flags.parsed_flag.iter(), codegen_mode) } pub fn create_rust_lib(mut input: Input) -> Result { diff --git a/tools/aconfig/src/main.rs b/tools/aconfig/src/main.rs index e20c60cd2c..72feb9406f 100644 --- a/tools/aconfig/src/main.rs +++ b/tools/aconfig/src/main.rs @@ -60,7 +60,13 @@ fn cli() -> Command { .subcommand( Command::new("create-cpp-lib") .arg(Arg::new("cache").long("cache").required(true)) - .arg(Arg::new("out").long("out").required(true)), + .arg(Arg::new("out").long("out").required(true)) + .arg( + Arg::new("mode") + .long("mode") + .value_parser(EnumValueParser::::new()) + .default_value("production"), + ), ) .subcommand( Command::new("create-rust-lib") @@ -163,9 +169,12 @@ fn main() -> Result<()> { } Some(("create-cpp-lib", sub_matches)) => { let cache = open_single_file(sub_matches, "cache")?; - let generated_file = commands::create_cpp_lib(cache)?; + let mode = get_required_arg::(sub_matches, "mode")?; + let generated_files = commands::create_cpp_lib(cache, *mode)?; let dir = PathBuf::from(get_required_arg::(sub_matches, "out")?); - write_output_file_realtive_to_dir(&dir, &generated_file)?; + generated_files + .iter() + .try_for_each(|file| write_output_file_realtive_to_dir(&dir, file))?; } Some(("create-rust-lib", sub_matches)) => { let cache = open_single_file(sub_matches, "cache")?; diff --git a/tools/aconfig/templates/cpp_exported_header.template b/tools/aconfig/templates/cpp_exported_header.template new file mode 100644 index 0000000000..e244de3d62 --- /dev/null +++ b/tools/aconfig/templates/cpp_exported_header.template @@ -0,0 +1,48 @@ +#ifndef {header}_HEADER_H +#define {header}_HEADER_H + +#include +#include +{{ if readwrite }} +#include +using namespace server_configurable_flags; +{{ endif }} +namespace {cpp_namespace} \{ + +class flag_provider_interface \{ +public: + virtual ~flag_provider_interface() = default; + {{ for item in class_elements}} + virtual bool {item.flag_name}() = 0; + {{ endfor }} + virtual void override_flag(std::string const&, bool) \{} + + virtual void reset_overrides() \{} +}; + +extern std::unique_ptr provider_; +{{ for item in class_elements}} +extern std::string const {item.uppercase_flag_name};{{ endfor }} +{{ for item in class_elements}} +inline bool {item.flag_name}() \{ + {{ if for_prod }} + {{ if not item.readwrite- }} + return {item.default_value}; + {{ -else- }} + return provider_->{item.flag_name}(); + {{ -endif }} + {{ -else- }} + return provider_->{item.flag_name}(); + {{ -endif }} +} +{{ endfor }} +inline void override_flag(std::string const& name, bool val) \{ + return provider_->override_flag(name, val); +} + +inline void reset_overrides() \{ + return provider_->reset_overrides(); +} + +} +#endif diff --git a/tools/aconfig/templates/cpp.template b/tools/aconfig/templates/cpp_prod_flag_provider.template similarity index 63% rename from tools/aconfig/templates/cpp.template rename to tools/aconfig/templates/cpp_prod_flag_provider.template index aa36d94ce4..c966ed4eb4 100644 --- a/tools/aconfig/templates/cpp.template +++ b/tools/aconfig/templates/cpp_prod_flag_provider.template @@ -1,12 +1,12 @@ -#ifndef {header}_HEADER_H -#define {header}_HEADER_H -{{ if readwrite }} -#include -using namespace server_configurable_flags; -{{ endif }} +#ifndef {header}_flag_provider_HEADER_H +#define {header}_flag_provider_HEADER_H +#include "{header}.h" + namespace {cpp_namespace} \{ +class flag_provider : public flag_provider_interface \{ +public: {{ for item in class_elements}} - static const bool {item.flag_name}() \{ + virtual bool {item.flag_name}() override \{ {{ if item.readwrite- }} return GetServerConfigurableFlag( "{item.device_config_namespace}", @@ -17,5 +17,6 @@ namespace {cpp_namespace} \{ {{ -endif }} } {{ endfor }} +}; } #endif diff --git a/tools/aconfig/templates/cpp_source_file.template b/tools/aconfig/templates/cpp_source_file.template new file mode 100644 index 0000000000..1b4f336707 --- /dev/null +++ b/tools/aconfig/templates/cpp_source_file.template @@ -0,0 +1,10 @@ + +#include "{header}.h" +#include "{header}_flag_provider.h" + +namespace {cpp_namespace} \{ +{{ for item in class_elements}} +std::string const {item.uppercase_flag_name} = "{item.device_config_flag}";{{ endfor }} +std::unique_ptr provider_ = + std::make_unique(); +} diff --git a/tools/aconfig/templates/cpp_test_flag_provider.template b/tools/aconfig/templates/cpp_test_flag_provider.template new file mode 100644 index 0000000000..bd597e7722 --- /dev/null +++ b/tools/aconfig/templates/cpp_test_flag_provider.template @@ -0,0 +1,49 @@ +#ifndef {header}_flag_provider_HEADER_H +#define {header}_flag_provider_HEADER_H +#include "{header}.h" + +#include +#include +#include + +namespace {cpp_namespace} \{ +class flag_provider : public flag_provider_interface \{ +private: + std::unordered_map overrides_; + std::unordered_set flag_names_; + +public: + flag_provider() + : overrides_(), + flag_names_() \{ + {{ for item in class_elements}} + flag_names_.insert({item.uppercase_flag_name});{{ endfor }} + } + {{ for item in class_elements}} + virtual bool {item.flag_name}() override \{ + auto it = overrides_.find({item.uppercase_flag_name}); + if (it != overrides_.end()) \{ + return it->second; + } else \{ + {{ if item.readwrite- }} + return GetServerConfigurableFlag( + "{item.device_config_namespace}", + "{item.device_config_flag}", + "{item.default_value}") == "true"; + {{ -else- }} + return {item.default_value}; + {{ -endif }} + } + } + {{ endfor }} + virtual void override_flag(std::string const& flag, bool val) override \{ + assert(flag_names_.count(flag)); + overrides_[flag] = val; + } + + virtual void reset_overrides() override \{ + overrides_.clear(); + } +}; +} +#endif