diff --git a/tools/aconfig/aconfig_storage_file/src/flag_table.rs b/tools/aconfig/aconfig_storage_file/src/flag_table.rs index 64b90eabfa..660edac0c7 100644 --- a/tools/aconfig/aconfig_storage_file/src/flag_table.rs +++ b/tools/aconfig/aconfig_storage_file/src/flag_table.rs @@ -150,7 +150,7 @@ impl FlagTableNode { /// Calculate node bucket index pub fn find_bucket_index(package_id: u32, flag_name: &str, num_buckets: u32) -> u32 { let full_flag_name = package_id.to_string() + "/" + flag_name; - get_bucket_index(&full_flag_name, num_buckets) + get_bucket_index(full_flag_name.as_bytes(), num_buckets) } } diff --git a/tools/aconfig/aconfig_storage_file/src/lib.rs b/tools/aconfig/aconfig_storage_file/src/lib.rs index 26e9c1a3be..b6367ffa35 100644 --- a/tools/aconfig/aconfig_storage_file/src/lib.rs +++ b/tools/aconfig/aconfig_storage_file/src/lib.rs @@ -37,19 +37,20 @@ pub mod flag_table; pub mod flag_value; pub mod package_table; pub mod protos; +pub mod sip_hasher13; pub mod test_utils; use anyhow::anyhow; use std::cmp::Ordering; -use std::collections::hash_map::DefaultHasher; use std::fs::File; -use std::hash::{Hash, Hasher}; +use std::hash::Hasher; use std::io::Read; pub use crate::flag_info::{FlagInfoBit, FlagInfoHeader, FlagInfoList, FlagInfoNode}; pub use crate::flag_table::{FlagTable, FlagTableHeader, FlagTableNode}; pub use crate::flag_value::{FlagValueHeader, FlagValueList}; pub use crate::package_table::{PackageTable, PackageTableHeader, PackageTableNode}; +pub use crate::sip_hasher13::SipHasher13; use crate::AconfigStorageError::{ BytesParseFail, HashTableSizeLimit, InvalidFlagValueType, InvalidStoredFlagType, @@ -211,10 +212,12 @@ pub fn get_table_size(entries: u32) -> Result { } /// Get the corresponding bucket index given the key and number of buckets -pub(crate) fn get_bucket_index(val: &T, num_buckets: u32) -> u32 { - let mut s = DefaultHasher::new(); - val.hash(&mut s); - (s.finish() % num_buckets as u64) as u32 +pub(crate) fn get_bucket_index(val: &[u8], num_buckets: u32) -> u32 { + let mut s = SipHasher13::new(); + s.write(val); + s.write_u8(0xff); + let ret = (s.finish() % num_buckets as u64) as u32; + ret } /// Read and parse bytes as u8 diff --git a/tools/aconfig/aconfig_storage_file/src/package_table.rs b/tools/aconfig/aconfig_storage_file/src/package_table.rs index b734972f33..007f86ed1a 100644 --- a/tools/aconfig/aconfig_storage_file/src/package_table.rs +++ b/tools/aconfig/aconfig_storage_file/src/package_table.rs @@ -146,7 +146,7 @@ impl PackageTableNode { /// construction side (aconfig binary) and consumption side (flag read lib) /// use the same method of hashing pub fn find_bucket_index(package: &str, num_buckets: u32) -> u32 { - get_bucket_index(&package, num_buckets) + get_bucket_index(package.as_bytes(), num_buckets) } } diff --git a/tools/aconfig/aconfig_storage_file/src/sip_hasher13.rs b/tools/aconfig/aconfig_storage_file/src/sip_hasher13.rs new file mode 100644 index 0000000000..9be3175e18 --- /dev/null +++ b/tools/aconfig/aconfig_storage_file/src/sip_hasher13.rs @@ -0,0 +1,327 @@ +/* + * Copyright (C) 2023 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! An implementation of SipHash13 + +use std::cmp; +use std::mem; +use std::ptr; +use std::slice; + +use std::hash::Hasher; + +/// An implementation of SipHash 2-4. +/// +#[derive(Debug, Clone, Default)] +pub struct SipHasher13 { + k0: u64, + k1: u64, + length: usize, // how many bytes we've processed + state: State, // hash State + tail: u64, // unprocessed bytes le + ntail: usize, // how many bytes in tail are valid +} + +#[derive(Debug, Clone, Copy, Default)] +#[repr(C)] +struct State { + // v0, v2 and v1, v3 show up in pairs in the algorithm, + // and simd implementations of SipHash will use vectors + // of v02 and v13. By placing them in this order in the struct, + // the compiler can pick up on just a few simd optimizations by itself. + v0: u64, + v2: u64, + v1: u64, + v3: u64, +} + +macro_rules! compress { + ($state:expr) => {{ + compress!($state.v0, $state.v1, $state.v2, $state.v3) + }}; + ($v0:expr, $v1:expr, $v2:expr, $v3:expr) => {{ + $v0 = $v0.wrapping_add($v1); + $v1 = $v1.rotate_left(13); + $v1 ^= $v0; + $v0 = $v0.rotate_left(32); + $v2 = $v2.wrapping_add($v3); + $v3 = $v3.rotate_left(16); + $v3 ^= $v2; + $v0 = $v0.wrapping_add($v3); + $v3 = $v3.rotate_left(21); + $v3 ^= $v0; + $v2 = $v2.wrapping_add($v1); + $v1 = $v1.rotate_left(17); + $v1 ^= $v2; + $v2 = $v2.rotate_left(32); + }}; +} + +/// Load an integer of the desired type from a byte stream, in LE order. Uses +/// `copy_nonoverlapping` to let the compiler generate the most efficient way +/// to load it from a possibly unaligned address. +/// +/// Unsafe because: unchecked indexing at i..i+size_of(int_ty) +macro_rules! load_int_le { + ($buf:expr, $i:expr, $int_ty:ident) => {{ + debug_assert!($i + mem::size_of::<$int_ty>() <= $buf.len()); + let mut data = 0 as $int_ty; + ptr::copy_nonoverlapping( + $buf.get_unchecked($i), + &mut data as *mut _ as *mut u8, + mem::size_of::<$int_ty>(), + ); + data.to_le() + }}; +} + +/// Load an u64 using up to 7 bytes of a byte slice. +/// +/// Unsafe because: unchecked indexing at start..start+len +#[inline] +unsafe fn u8to64_le(buf: &[u8], start: usize, len: usize) -> u64 { + debug_assert!(len < 8); + let mut i = 0; // current byte index (from LSB) in the output u64 + let mut out = 0; + if i + 3 < len { + out = load_int_le!(buf, start + i, u32) as u64; + i += 4; + } + if i + 1 < len { + out |= (load_int_le!(buf, start + i, u16) as u64) << (i * 8); + i += 2 + } + if i < len { + out |= (*buf.get_unchecked(start + i) as u64) << (i * 8); + i += 1; + } + debug_assert_eq!(i, len); + out +} + +impl SipHasher13 { + /// Creates a new `SipHasher13` with the two initial keys set to 0. + #[inline] + pub fn new() -> SipHasher13 { + SipHasher13::new_with_keys(0, 0) + } + + /// Creates a `SipHasher13` that is keyed off the provided keys. + #[inline] + pub fn new_with_keys(key0: u64, key1: u64) -> SipHasher13 { + let mut sip_hasher = SipHasher13 { + k0: key0, + k1: key1, + length: 0, + state: State { v0: 0, v1: 0, v2: 0, v3: 0 }, + tail: 0, + ntail: 0, + }; + sip_hasher.reset(); + sip_hasher + } + + #[inline] + fn c_rounds(state: &mut State) { + compress!(state); + } + + #[inline] + fn d_rounds(state: &mut State) { + compress!(state); + compress!(state); + compress!(state); + } + + #[inline] + fn reset(&mut self) { + self.length = 0; + self.state.v0 = self.k0 ^ 0x736f6d6570736575; + self.state.v1 = self.k1 ^ 0x646f72616e646f6d; + self.state.v2 = self.k0 ^ 0x6c7967656e657261; + self.state.v3 = self.k1 ^ 0x7465646279746573; + self.ntail = 0; + } + + // Specialized write function that is only valid for buffers with len <= 8. + // It's used to force inlining of write_u8 and write_usize, those would normally be inlined + // except for composite types (that includes slices and str hashing because of delimiter). + // Without this extra push the compiler is very reluctant to inline delimiter writes, + // degrading performance substantially for the most common use cases. + #[inline] + fn short_write(&mut self, msg: &[u8]) { + debug_assert!(msg.len() <= 8); + let length = msg.len(); + self.length += length; + + let needed = 8 - self.ntail; + let fill = cmp::min(length, needed); + if fill == 8 { + // safe to call since msg hasn't been loaded + self.tail = unsafe { load_int_le!(msg, 0, u64) }; + } else { + // safe to call since msg hasn't been loaded, and fill <= msg.len() + self.tail |= unsafe { u8to64_le(msg, 0, fill) } << (8 * self.ntail); + if length < needed { + self.ntail += length; + return; + } + } + self.state.v3 ^= self.tail; + Self::c_rounds(&mut self.state); + self.state.v0 ^= self.tail; + + // Buffered tail is now flushed, process new input. + self.ntail = length - needed; + // safe to call since number of `needed` bytes has been loaded + // and self.ntail + needed == msg.len() + self.tail = unsafe { u8to64_le(msg, needed, self.ntail) }; + } +} + +impl Hasher for SipHasher13 { + // see short_write comment for explanation + #[inline] + fn write_usize(&mut self, i: usize) { + // safe to call, since convert the pointer to u8 + let bytes = unsafe { + slice::from_raw_parts(&i as *const usize as *const u8, mem::size_of::()) + }; + self.short_write(bytes); + } + + // see short_write comment for explanation + #[inline] + fn write_u8(&mut self, i: u8) { + self.short_write(&[i]); + } + + #[inline] + fn write(&mut self, msg: &[u8]) { + let length = msg.len(); + self.length += length; + + let mut needed = 0; + + // loading unprocessed byte from last write + if self.ntail != 0 { + needed = 8 - self.ntail; + // safe to call, since msg hasn't been processed + // and cmp::min(length, needed) < 8 + self.tail |= unsafe { u8to64_le(msg, 0, cmp::min(length, needed)) } << 8 * self.ntail; + if length < needed { + self.ntail += length; + return; + } else { + self.state.v3 ^= self.tail; + Self::c_rounds(&mut self.state); + self.state.v0 ^= self.tail; + self.ntail = 0; + } + } + + // Buffered tail is now flushed, process new input. + let len = length - needed; + let left = len & 0x7; + + let mut i = needed; + while i < len - left { + // safe to call since if i < len - left, it means msg has at least 1 byte to load + let mi = unsafe { load_int_le!(msg, i, u64) }; + + self.state.v3 ^= mi; + Self::c_rounds(&mut self.state); + self.state.v0 ^= mi; + + i += 8; + } + + // safe to call since if left == 0, since this call will load nothing + // if left > 0, it means there are number of `left` bytes in msg + self.tail = unsafe { u8to64_le(msg, i, left) }; + self.ntail = left; + } + + #[inline] + fn finish(&self) -> u64 { + let mut state = self.state; + + let b: u64 = ((self.length as u64 & 0xff) << 56) | self.tail; + + state.v3 ^= b; + Self::c_rounds(&mut state); + state.v0 ^= b; + + state.v2 ^= 0xff; + Self::d_rounds(&mut state); + + state.v0 ^ state.v1 ^ state.v2 ^ state.v3 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::hash::{Hash, Hasher}; + use std::string::String; + + #[test] + // this test point locks down the value list serialization + fn test_sip_hash13_string_hash() { + let mut sip_hash13 = SipHasher13::new(); + let test_str1 = String::from("com.google.android.test"); + test_str1.hash(&mut sip_hash13); + assert_eq!(17898838669067067585, sip_hash13.finish()); + + let test_str2 = String::from("adfadfadf adfafadadf 1231241241"); + test_str2.hash(&mut sip_hash13); + assert_eq!(13543518987672889310, sip_hash13.finish()); + } + + #[test] + fn test_sip_hash13_write() { + let mut sip_hash13 = SipHasher13::new(); + let test_str1 = String::from("com.google.android.test"); + sip_hash13.write(test_str1.as_bytes()); + sip_hash13.write_u8(0xff); + assert_eq!(17898838669067067585, sip_hash13.finish()); + + let mut sip_hash132 = SipHasher13::new(); + let test_str1 = String::from("com.google.android.test"); + sip_hash132.write(test_str1.as_bytes()); + assert_eq!(9685440969685209025, sip_hash132.finish()); + sip_hash132.write(test_str1.as_bytes()); + assert_eq!(6719694176662736568, sip_hash132.finish()); + + let mut sip_hash133 = SipHasher13::new(); + let test_str2 = String::from("abcdefg"); + test_str2.hash(&mut sip_hash133); + assert_eq!(2492161047327640297, sip_hash133.finish()); + + let mut sip_hash134 = SipHasher13::new(); + let test_str3 = String::from("abcdefgh"); + test_str3.hash(&mut sip_hash134); + assert_eq!(6689927370435554326, sip_hash134.finish()); + } + + #[test] + fn test_sip_hash13_write_short() { + let mut sip_hash13 = SipHasher13::new(); + sip_hash13.write_u8(0x61); + assert_eq!(4644417185603328019, sip_hash13.finish()); + } +} diff --git a/tools/aconfig/aconfig_storage_file/srcs/android/aconfig/storage/SipHasher13.java b/tools/aconfig/aconfig_storage_file/srcs/android/aconfig/storage/SipHasher13.java new file mode 100644 index 0000000000..8faee58a76 --- /dev/null +++ b/tools/aconfig/aconfig_storage_file/srcs/android/aconfig/storage/SipHasher13.java @@ -0,0 +1,119 @@ +/* + * Copyright (C) 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.aconfig.storage; + +public class SipHasher13 { + static class State { + private long v0; + private long v2; + private long v1; + private long v3; + + public State(long k0, long k1) { + v0 = k0 ^ 0x736f6d6570736575L; + v1 = k1 ^ 0x646f72616e646f6dL; + v2 = k0 ^ 0x6c7967656e657261L; + v3 = k1 ^ 0x7465646279746573L; + } + + public void compress(long m) { + v3 ^= m; + cRounds(); + v0 ^= m; + } + + public long finish() { + v2 ^= 0xff; + dRounds(); + return v0 ^ v1 ^ v2 ^ v3; + } + + private void cRounds() { + v0 += v1; + v1 = rotateLeft(v1, 13); + v1 ^= v0; + v0 = rotateLeft(v0, 32); + v2 += v3; + v3 = rotateLeft(v3, 16); + v3 ^= v2; + v0 += v3; + v3 = rotateLeft(v3, 21); + v3 ^= v0; + v2 += v1; + v1 = rotateLeft(v1, 17); + v1 ^= v2; + v2 = rotateLeft(v2, 32); + } + + private void dRounds() { + for (int i = 0; i < 3; i++) { + v0 += v1; + v1 = rotateLeft(v1, 13); + v1 ^= v0; + v0 = rotateLeft(v0, 32); + v2 += v3; + v3 = rotateLeft(v3, 16); + v3 ^= v2; + v0 += v3; + v3 = rotateLeft(v3, 21); + v3 ^= v0; + v2 += v1; + v1 = rotateLeft(v1, 17); + v1 ^= v2; + v2 = rotateLeft(v2, 32); + } + } + + private static long rotateLeft(long value, int shift) { + return (value << shift) | value >>> (64 - shift); + } + } + + public static long hash(byte[] data) { + State state = new State(0, 0); + int len = data.length; + int left = len & 0x7; + int index = 0; + + while (index < len - left) { + long mi = loadLe(data, index, 8); + index += 8; + state.compress(mi); + } + + // padding the end with 0xff to be consistent with rust + long m = (0xffL << (left * 8)) | loadLe(data, index, left); + if (left == 0x7) { + // compress the m w-2 + state.compress(m); + m = 0L; + } + // len adds 1 since padded 0xff + m |= (((len + 1) & 0xffL) << 56); + state.compress(m); + + return state.finish(); + } + + private static long loadLe(byte[] data, int offset, int size) { + long m = 0; + for (int i = 0; i < size; i++) { + m |= (data[i + offset] & 0xffL) << (i * 8); + } + return m; + } +} diff --git a/tools/aconfig/aconfig_storage_file/tests/srcs/SipHasher13Test.java b/tools/aconfig/aconfig_storage_file/tests/srcs/SipHasher13Test.java new file mode 100644 index 0000000000..10620d272b --- /dev/null +++ b/tools/aconfig/aconfig_storage_file/tests/srcs/SipHasher13Test.java @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.aconfig.storage.test; + +import static org.junit.Assert.assertEquals; +import static java.nio.charset.StandardCharsets.UTF_8; + +import android.aconfig.storage.SipHasher13; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class SipHasher13Test { + @Test + public void testSipHash_hashString() throws Exception { + String testStr = "com.google.android.test"; + long result = SipHasher13.hash(testStr.getBytes(UTF_8)); + assertEquals(0xF86572EFF9C4A0C1L, result); + + testStr = "abcdefg"; + result = SipHasher13.hash(testStr.getBytes(UTF_8)); + assertEquals(0x2295EF44BD078AE9L, result); + + testStr = "abcdefgh"; + result = SipHasher13.hash(testStr.getBytes(UTF_8)); + assertEquals(0x5CD7657FA7F96C16L, result); + } +}