diff --git a/tools/aconfig/src/codegen_rust.rs b/tools/aconfig/src/codegen_rust.rs index f931418a7b..45303ce5f8 100644 --- a/tools/aconfig/src/codegen_rust.rs +++ b/tools/aconfig/src/codegen_rust.rs @@ -19,10 +19,14 @@ use serde::Serialize; use tinytemplate::TinyTemplate; use crate::codegen; -use crate::commands::OutputFile; +use crate::commands::{CodegenMode, OutputFile}; use crate::protos::{ProtoFlagPermission, ProtoFlagState, ProtoParsedFlag}; -pub fn generate_rust_code<'a, I>(package: &str, parsed_flags_iter: I) -> Result +pub fn generate_rust_code<'a, I>( + package: &str, + parsed_flags_iter: I, + codegen_mode: CodegenMode, +) -> Result where I: Iterator, { @@ -34,7 +38,13 @@ where modules: package.split('.').map(|s| s.to_string()).collect::>(), }; let mut template = TinyTemplate::new(); - template.add_template("rust_code_gen", include_str!("../templates/rust.template"))?; + template.add_template( + "rust_code_gen", + match codegen_mode { + CodegenMode::Production => include_str!("../templates/rust_prod.template"), + CodegenMode::Test => include_str!("../templates/rust_test.template"), + }, + )?; let contents = template.render("rust_code_gen", &context)?; let path = ["src", "lib.rs"].iter().collect(); Ok(OutputFile { contents: contents.into(), path }) @@ -49,41 +59,27 @@ struct TemplateContext { #[derive(Serialize)] struct TemplateParsedFlag { + pub readwrite: bool, + pub default_value: String, pub name: String, pub device_config_namespace: String, pub device_config_flag: String, - - // TinyTemplate's conditionals are limited to single expressions; list all options here - // Invariant: exactly one of these fields will be true - pub is_read_only_enabled: bool, - pub is_read_only_disabled: bool, - pub is_read_write: bool, } impl TemplateParsedFlag { #[allow(clippy::nonminimal_bool)] fn new(package: &str, pf: &ProtoParsedFlag) -> Self { let template = TemplateParsedFlag { + readwrite: pf.permission() == ProtoFlagPermission::READ_WRITE, + default_value: match pf.state() { + ProtoFlagState::ENABLED => "true".to_string(), + ProtoFlagState::DISABLED => "false".to_string(), + }, name: pf.name().to_string(), device_config_namespace: pf.namespace().to_string(), device_config_flag: codegen::create_device_config_ident(package, pf.name()) .expect("values checked at flag parse time"), - is_read_only_enabled: pf.permission() == ProtoFlagPermission::READ_ONLY - && pf.state() == ProtoFlagState::ENABLED, - is_read_only_disabled: pf.permission() == ProtoFlagPermission::READ_ONLY - && pf.state() == ProtoFlagState::DISABLED, - is_read_write: pf.permission() == ProtoFlagPermission::READ_WRITE, }; - #[rustfmt::skip] - debug_assert!( - (template.is_read_only_enabled && !template.is_read_only_disabled && !template.is_read_write) || - (!template.is_read_only_enabled && template.is_read_only_disabled && !template.is_read_write) || - (!template.is_read_only_enabled && !template.is_read_only_disabled && template.is_read_write), - "TemplateParsedFlag invariant failed: {} {} {}", - template.is_read_only_enabled, - template.is_read_only_disabled, - template.is_read_write, - ); template } } @@ -92,48 +88,224 @@ impl TemplateParsedFlag { mod tests { use super::*; - #[test] - fn test_generate_rust_code() { - let parsed_flags = crate::test::parse_test_flags(); - let generated = - generate_rust_code(crate::test::TEST_PACKAGE, parsed_flags.parsed_flag.iter()).unwrap(); - assert_eq!("src/lib.rs", format!("{}", generated.path.display())); - let expected = r#" -pub mod com { -pub mod android { -pub mod aconfig { -pub mod test { + const PROD_EXPECTED: &str = r#" +//! codegenerated rust flag lib + +/// flag provider +pub struct FlagProvider + +impl FlagProvider { + /// query flag disabled_ro + pub fn disabled_ro(&self) -> bool { + false + } + + /// query flag disabled_rw + pub fn disabled_rw(&self) -> bool { + flags_rust::GetServerConfigurableFlag( + "aconfig_test", + "com.android.aconfig.test.disabled_rw", + "false") == "true" + } + + /// query flag enabled_ro + pub fn enabled_ro(&self) -> bool { + true + } + + /// query flag enabled_rw + pub fn enabled_rw(&self) -> bool { + flags_rust::GetServerConfigurableFlag( + "aconfig_test", + "com.android.aconfig.test.enabled_rw", + "true") == "true" + } +} + +/// flag provider +pub static PROVIDER: FlagProvider = FlagProvider; + +/// query flag disabled_ro #[inline(always)] -pub const fn r#disabled_ro() -> bool { +pub fn disabled_ro() -> bool { false } +/// query flag disabled_rw #[inline(always)] -pub fn r#disabled_rw() -> bool { - flags_rust::GetServerConfigurableFlag("aconfig_test", "com.android.aconfig.test.disabled_rw", "false") == "true" +pub fn disabled_rw() -> bool { + PROVIDER.disabled_rw() } +/// query flag enabled_ro #[inline(always)] -pub const fn r#enabled_ro() -> bool { +pub fn enabled_ro() -> bool { true } +/// query flag enabled_rw #[inline(always)] -pub fn r#enabled_rw() -> bool { - flags_rust::GetServerConfigurableFlag("aconfig_test", "com.android.aconfig.test.enabled_rw", "false") == "true" -} - -} -} -} +pub fn enabled_rw() -> bool { + PROVIDER.enabled_rw() } "#; + + const TEST_EXPECTED: &str = r#" +//! codegenerated rust flag lib + +use std::collections::BTreeMap; +use std::sync::Mutex; + +/// flag provider +pub struct FlagProvider { + overrides: BTreeMap<&'static str, bool>, +} + +impl FlagProvider { + /// query flag disabled_ro + pub fn disabled_ro(&self) -> bool { + self.overrides.get("disabled_ro").copied().unwrap_or( + false + ) + } + + /// set flag disabled_ro + pub fn set_disabled_ro(&mut self, val: bool) { + self.overrides.insert("disabled_ro", val); + } + + /// query flag disabled_rw + pub fn disabled_rw(&self) -> bool { + self.overrides.get("disabled_rw").copied().unwrap_or( + flags_rust::GetServerConfigurableFlag( + "aconfig_test", + "com.android.aconfig.test.disabled_rw", + "false") == "true" + ) + } + + /// set flag disabled_rw + pub fn set_disabled_rw(&mut self, val: bool) { + self.overrides.insert("disabled_rw", val); + } + + /// query flag enabled_ro + pub fn enabled_ro(&self) -> bool { + self.overrides.get("enabled_ro").copied().unwrap_or( + true + ) + } + + /// set flag enabled_ro + pub fn set_enabled_ro(&mut self, val: bool) { + self.overrides.insert("enabled_ro", val); + } + + /// query flag enabled_rw + pub fn enabled_rw(&self) -> bool { + self.overrides.get("enabled_rw").copied().unwrap_or( + flags_rust::GetServerConfigurableFlag( + "aconfig_test", + "com.android.aconfig.test.enabled_rw", + "true") == "true" + ) + } + + /// set flag enabled_rw + pub fn set_enabled_rw(&mut self, val: bool) { + self.overrides.insert("enabled_rw", val); + } + + /// clear all flag overrides + pub fn reset_flags(&mut self) { + self.overrides.clear(); + } +} + +/// flag provider +pub static PROVIDER: Mutex = Mutex::new( + FlagProvider {overrides: BTreeMap::new()} +); + +/// query flag disabled_ro +#[inline(always)] +pub fn disabled_ro() -> bool { + PROVIDER.lock().unwrap().disabled_ro() +} + +/// set flag disabled_ro +#[inline(always)] +pub fn set_disabled_ro(val: bool) { + PROVIDER.lock().unwrap().set_disabled_ro(val); +} + +/// query flag disabled_rw +#[inline(always)] +pub fn disabled_rw() -> bool { + PROVIDER.lock().unwrap().disabled_rw() +} + +/// set flag disabled_rw +#[inline(always)] +pub fn set_disabled_rw(val: bool) { + PROVIDER.lock().unwrap().set_disabled_rw(val); +} + +/// query flag enabled_ro +#[inline(always)] +pub fn enabled_ro() -> bool { + PROVIDER.lock().unwrap().enabled_ro() +} + +/// set flag enabled_ro +#[inline(always)] +pub fn set_enabled_ro(val: bool) { + PROVIDER.lock().unwrap().set_enabled_ro(val); +} + +/// query flag enabled_rw +#[inline(always)] +pub fn enabled_rw() -> bool { + PROVIDER.lock().unwrap().enabled_rw() +} + +/// set flag enabled_rw +#[inline(always)] +pub fn set_enabled_rw(val: bool) { + PROVIDER.lock().unwrap().set_enabled_rw(val); +} + +/// clear all flag override +pub fn reset_flags() { + PROVIDER.lock().unwrap().reset_flags() +} +"#; + + fn test_generate_rust_code(mode: CodegenMode) { + let parsed_flags = crate::test::parse_test_flags(); + let generated = + generate_rust_code(crate::test::TEST_PACKAGE, parsed_flags.parsed_flag.iter(), mode) + .unwrap(); + assert_eq!("src/lib.rs", format!("{}", generated.path.display())); assert_eq!( None, crate::test::first_significant_code_diff( - expected, + match mode { + CodegenMode::Production => PROD_EXPECTED, + CodegenMode::Test => TEST_EXPECTED, + }, &String::from_utf8(generated.contents).unwrap() ) ); } + + #[test] + fn test_generate_rust_code_for_prod() { + test_generate_rust_code(CodegenMode::Production); + } + + #[test] + fn test_generate_rust_code_for_test() { + test_generate_rust_code(CodegenMode::Test); + } } diff --git a/tools/aconfig/src/commands.rs b/tools/aconfig/src/commands.rs index 687f3195cf..0ac84b24a2 100644 --- a/tools/aconfig/src/commands.rs +++ b/tools/aconfig/src/commands.rs @@ -108,7 +108,11 @@ pub fn parse_flags(package: &str, declarations: Vec, values: Vec) crate::protos::flag_value::verify_fields(&flag_value) .with_context(|| format!("Failed to parse {}", input.source))?; - let Some(parsed_flag) = parsed_flags.parsed_flag.iter_mut().find(|pf| pf.package() == flag_value.package() && pf.name() == flag_value.name()) else { + let Some(parsed_flag) = parsed_flags + .parsed_flag + .iter_mut() + .find(|pf| pf.package() == flag_value.package() && pf.name() == flag_value.name()) + else { // (silently) skip unknown flags continue; }; @@ -151,12 +155,12 @@ pub fn create_cpp_lib(mut input: Input, codegen_mode: CodegenMode) -> Result Result { +pub fn create_rust_lib(mut input: Input, codegen_mode: CodegenMode) -> Result { let parsed_flags = input.try_parse_flags()?; let Some(package) = find_unique_package(&parsed_flags) else { bail!("no parsed flags, or the parsed flags use different packages"); }; - generate_rust_code(package, parsed_flags.parsed_flag.iter()) + generate_rust_code(package, parsed_flags.parsed_flag.iter(), codegen_mode) } pub fn create_device_config_defaults(mut input: Input) -> Result> { diff --git a/tools/aconfig/src/main.rs b/tools/aconfig/src/main.rs index 72feb9406f..151cbe8f2a 100644 --- a/tools/aconfig/src/main.rs +++ b/tools/aconfig/src/main.rs @@ -71,7 +71,13 @@ fn cli() -> Command { .subcommand( Command::new("create-rust-lib") .arg(Arg::new("cache").long("cache").required(true)) - .arg(Arg::new("out").long("out").required(true)), + .arg(Arg::new("out").long("out").required(true)) + .arg( + Arg::new("mode") + .long("mode") + .value_parser(EnumValueParser::::new()) + .default_value("production"), + ), ) .subcommand( Command::new("create-device-config-defaults") @@ -178,7 +184,8 @@ fn main() -> Result<()> { } Some(("create-rust-lib", sub_matches)) => { let cache = open_single_file(sub_matches, "cache")?; - let generated_file = commands::create_rust_lib(cache)?; + let mode = get_required_arg::(sub_matches, "mode")?; + let generated_file = commands::create_rust_lib(cache, *mode)?; let dir = PathBuf::from(get_required_arg::(sub_matches, "out")?); write_output_file_realtive_to_dir(&dir, &generated_file)?; } diff --git a/tools/aconfig/templates/rust.template b/tools/aconfig/templates/rust.template deleted file mode 100644 index 960c494942..0000000000 --- a/tools/aconfig/templates/rust.template +++ /dev/null @@ -1,29 +0,0 @@ -{{- for mod in modules -}} -pub mod {mod} \{ -{{ endfor -}} -{{- for flag in template_flags -}} -{{- if flag.is_read_only_disabled -}} -#[inline(always)] -pub const fn r#{flag.name}() -> bool \{ - false -} - -{{ endif -}} -{{- if flag.is_read_only_enabled -}} -#[inline(always)] -pub const fn r#{flag.name}() -> bool \{ - true -} - -{{ endif -}} -{{- if flag.is_read_write -}} -#[inline(always)] -pub fn r#{flag.name}() -> bool \{ - flags_rust::GetServerConfigurableFlag("{flag.device_config_namespace}", "{flag.device_config_flag}", "false") == "true" -} - -{{ endif -}} -{{- endfor -}} -{{- for mod in modules -}} -} -{{ endfor -}} diff --git a/tools/aconfig/templates/rust_prod.template b/tools/aconfig/templates/rust_prod.template new file mode 100644 index 0000000000..543107e8bd --- /dev/null +++ b/tools/aconfig/templates/rust_prod.template @@ -0,0 +1,38 @@ +//! codegenerated rust flag lib + +/// flag provider +pub struct FlagProvider + +impl FlagProvider \{ + + {{ for flag in template_flags }} + /// query flag {flag.name} + pub fn {flag.name}(&self) -> bool \{ + {{ if flag.readwrite -}} + flags_rust::GetServerConfigurableFlag( + "{flag.device_config_namespace}", + "{flag.device_config_flag}", + "{flag.default_value}") == "true" + {{ -else- }} + {flag.default_value} + {{ -endif }} + } + {{ endfor }} + +} + +/// flag provider +pub static PROVIDER: FlagProvider = FlagProvider; + +{{ for flag in template_flags }} +/// query flag {flag.name} +#[inline(always)] +{{ if flag.readwrite -}} +pub fn {flag.name}() -> bool \{ + PROVIDER.{flag.name}() +{{ -else- }} +pub fn {flag.name}() -> bool \{ + {flag.default_value} +{{ -endif }} +} +{{ endfor }} diff --git a/tools/aconfig/templates/rust_test.template b/tools/aconfig/templates/rust_test.template new file mode 100644 index 0000000000..1e2c28a112 --- /dev/null +++ b/tools/aconfig/templates/rust_test.template @@ -0,0 +1,61 @@ +//! codegenerated rust flag lib + +use std::collections::BTreeMap; +use std::sync::Mutex; + +/// flag provider +pub struct FlagProvider \{ + overrides: BTreeMap<&'static str, bool>, +} + +impl FlagProvider \{ + {{ for flag in template_flags }} + /// query flag {flag.name} + pub fn {flag.name}(&self) -> bool \{ + self.overrides.get("{flag.name}").copied().unwrap_or( + {{ if flag.readwrite -}} + flags_rust::GetServerConfigurableFlag( + "{flag.device_config_namespace}", + "{flag.device_config_flag}", + "{flag.default_value}") == "true" + {{ -else- }} + {flag.default_value} + {{ -endif }} + ) + } + + /// set flag {flag.name} + pub fn set_{flag.name}(&mut self, val: bool) \{ + self.overrides.insert("{flag.name}", val); + } + {{ endfor }} + + /// clear all flag overrides + pub fn reset_flags(&mut self) \{ + self.overrides.clear(); + } +} + +/// flag provider +pub static PROVIDER: Mutex = Mutex::new( + FlagProvider \{overrides: BTreeMap::new()} +); + +{{ for flag in template_flags }} +/// query flag {flag.name} +#[inline(always)] +pub fn {flag.name}() -> bool \{ + PROVIDER.lock().unwrap().{flag.name}() +} + +/// set flag {flag.name} +#[inline(always)] +pub fn set_{flag.name}(val: bool) \{ + PROVIDER.lock().unwrap().set_{flag.name}(val); +} +{{ endfor }} + +/// clear all flag override +pub fn reset_flags() \{ + PROVIDER.lock().unwrap().reset_flags() +}