Implement SipHasher

This change implements SipHasher in Rust and Java to make sure the same
byte array is hashed to the same u64 value in both languages.

This implementation is needed so that when Rust and Java code read the
same flag file, they can find the same entry based on the same key.

Test: atest aconfig_storage_file.test.java aconfig_storage_file.test.cpp
Bug: 352078117
Change-Id: I2ce470039213a09a1df7637e60f4649b053fb2ea
This commit is contained in:
Zhi Dou
2024-08-26 20:45:49 +00:00
parent 58ef2deebf
commit bca30dd13f
6 changed files with 501 additions and 8 deletions

View File

@@ -150,7 +150,7 @@ impl FlagTableNode {
/// Calculate node bucket index
///
/// The hash key is the string `<package_id>/<flag_name>`: the decimal form of
/// `package_id`, a `/` separator, then `flag_name`. The key is handed to
/// `get_bucket_index` together with `num_buckets`.
pub fn find_bucket_index(package_id: u32, flag_name: &str, num_buckets: u32) -> u32 {
let full_flag_name = package_id.to_string() + "/" + flag_name;
get_bucket_index(&full_flag_name, num_buckets)
get_bucket_index(full_flag_name.as_bytes(), num_buckets)
}
}

View File

@@ -37,19 +37,20 @@ pub mod flag_table;
pub mod flag_value;
pub mod package_table;
pub mod protos;
pub mod sip_hasher13;
pub mod test_utils;
use anyhow::anyhow;
use std::cmp::Ordering;
use std::collections::hash_map::DefaultHasher;
use std::fs::File;
use std::hash::{Hash, Hasher};
use std::hash::Hasher;
use std::io::Read;
pub use crate::flag_info::{FlagInfoBit, FlagInfoHeader, FlagInfoList, FlagInfoNode};
pub use crate::flag_table::{FlagTable, FlagTableHeader, FlagTableNode};
pub use crate::flag_value::{FlagValueHeader, FlagValueList};
pub use crate::package_table::{PackageTable, PackageTableHeader, PackageTableNode};
pub use crate::sip_hasher13::SipHasher13;
use crate::AconfigStorageError::{
BytesParseFail, HashTableSizeLimit, InvalidFlagValueType, InvalidStoredFlagType,
@@ -211,10 +212,12 @@ pub fn get_table_size(entries: u32) -> Result<u32, AconfigStorageError> {
}
/// Get the corresponding bucket index given the key and number of buckets
pub(crate) fn get_bucket_index<T: Hash>(val: &T, num_buckets: u32) -> u32 {
let mut s = DefaultHasher::new();
val.hash(&mut s);
(s.finish() % num_buckets as u64) as u32
/// Hashes `val` with a zero-keyed [`SipHasher13`], followed by a single `0xff`
/// terminator byte, and reduces the 64-bit hash modulo `num_buckets`.
///
/// # Panics
///
/// Panics if `num_buckets` is 0 (remainder by zero).
pub(crate) fn get_bucket_index(val: &[u8], num_buckets: u32) -> u32 {
    let mut hasher = SipHasher13::new();
    hasher.write(val);
    // The trailing 0xff is the same byte sequence that hashing a `str` through
    // `Hash` feeds into a hasher, and it is what the Java `SipHasher13.hash`
    // pads its input with, so both sides arrive at the same value.
    hasher.write_u8(0xff);
    // The remainder is strictly smaller than `num_buckets`, so narrowing it
    // back to u32 cannot truncate.
    (hasher.finish() % u64::from(num_buckets)) as u32
}
/// Read and parse bytes as u8

View File

@@ -146,7 +146,7 @@ impl PackageTableNode {
/// construction side (aconfig binary) and consumption side (flag read lib)
/// use the same method of hashing
pub fn find_bucket_index(package: &str, num_buckets: u32) -> u32 {
// The package name itself is the hash key; bucket selection is delegated to
// the crate-level `get_bucket_index` helper.
get_bucket_index(&package, num_buckets)
get_bucket_index(package.as_bytes(), num_buckets)
}
}

View File

@@ -0,0 +1,327 @@
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//! An implementation of SipHash13
use std::cmp;
use std::mem;
use std::ptr;
use std::slice;
use std::hash::Hasher;
/// An implementation of SipHash 1-3: one compression round per 8-byte message
/// word and three finalization rounds.
///
/// Both [`SipHasher13::new`] and [`Default::default`] produce a hasher keyed
/// with two zero keys, so the resulting hash depends only on the bytes written.
#[derive(Debug, Clone)]
pub struct SipHasher13 {
    k0: u64,
    k1: u64,
    length: usize, // how many bytes we've processed
    state: State,  // hash State
    tail: u64,     // unprocessed bytes le
    ntail: usize,  // how many bytes in tail are valid
}

impl Default for SipHasher13 {
    /// Creates a `SipHasher13` with zero keys, identical to [`SipHasher13::new`].
    ///
    /// This is implemented by hand because a derived `Default` would leave the
    /// internal state all-zero instead of mixing in the SipHash initialization
    /// constants, and would therefore hash differently from `new()`.
    #[inline]
    fn default() -> SipHasher13 {
        SipHasher13::new()
    }
}

/// The four 64-bit words of SipHash internal state.
#[derive(Debug, Clone, Copy, Default)]
#[repr(C)]
struct State {
    // v0, v2 and v1, v3 show up in pairs in the algorithm,
    // and simd implementations of SipHash will use vectors
    // of v02 and v13. By placing them in this order in the struct,
    // the compiler can pick up on just a few simd optimizations by itself.
    v0: u64,
    v2: u64,
    v1: u64,
    v3: u64,
}

/// Applies one SipRound (add / rotate / xor mixing) to the state in place.
///
/// Accepts either a `State` place expression or the four words individually.
macro_rules! compress {
    ($state:expr) => {{
        compress!($state.v0, $state.v1, $state.v2, $state.v3)
    }};
    ($v0:expr, $v1:expr, $v2:expr, $v3:expr) => {{
        $v0 = $v0.wrapping_add($v1);
        $v1 = $v1.rotate_left(13);
        $v1 ^= $v0;
        $v0 = $v0.rotate_left(32);
        $v2 = $v2.wrapping_add($v3);
        $v3 = $v3.rotate_left(16);
        $v3 ^= $v2;
        $v0 = $v0.wrapping_add($v3);
        $v3 = $v3.rotate_left(21);
        $v3 ^= $v0;
        $v2 = $v2.wrapping_add($v1);
        $v1 = $v1.rotate_left(17);
        $v1 ^= $v2;
        $v2 = $v2.rotate_left(32);
    }};
}

/// Load an integer of the desired type from a byte stream, in LE order. Uses
/// `copy_nonoverlapping` to let the compiler generate the most efficient way
/// to load it from a possibly unaligned address.
///
/// Unsafe because: unchecked indexing at i..i+size_of(int_ty). The caller must
/// guarantee that this whole range lies inside `$buf`; the `debug_assert!`
/// only checks it in debug builds.
macro_rules! load_int_le {
    ($buf:expr, $i:expr, $int_ty:ident) => {{
        debug_assert!($i + mem::size_of::<$int_ty>() <= $buf.len());
        let mut data = 0 as $int_ty;
        ptr::copy_nonoverlapping(
            $buf.get_unchecked($i),
            &mut data as *mut _ as *mut u8,
            mem::size_of::<$int_ty>(),
        );
        data.to_le()
    }};
}

/// Load an u64 using up to 7 bytes of a byte slice.
///
/// The bytes are assembled least-significant first by reading at most one
/// 4-byte, one 2-byte and one 1-byte piece.
///
/// # Safety
///
/// Unsafe because: unchecked indexing at start..start+len. The caller must
/// guarantee `len < 8` and `start + len <= buf.len()`.
#[inline]
unsafe fn u8to64_le(buf: &[u8], start: usize, len: usize) -> u64 {
    debug_assert!(len < 8);
    let mut i = 0; // current byte index (from LSB) in the output u64
    let mut out = 0;
    if i + 3 < len {
        out = load_int_le!(buf, start + i, u32) as u64;
        i += 4;
    }
    if i + 1 < len {
        out |= (load_int_le!(buf, start + i, u16) as u64) << (i * 8);
        i += 2
    }
    if i < len {
        out |= (*buf.get_unchecked(start + i) as u64) << (i * 8);
        i += 1;
    }
    debug_assert_eq!(i, len);
    out
}

impl SipHasher13 {
    /// Creates a new `SipHasher13` with the two initial keys set to 0.
    #[inline]
    pub fn new() -> SipHasher13 {
        SipHasher13::new_with_keys(0, 0)
    }

    /// Creates a `SipHasher13` that is keyed off the provided keys.
    #[inline]
    pub fn new_with_keys(key0: u64, key1: u64) -> SipHasher13 {
        let mut sip_hasher = SipHasher13 {
            k0: key0,
            k1: key1,
            length: 0,
            state: State { v0: 0, v1: 0, v2: 0, v3: 0 },
            tail: 0,
            ntail: 0,
        };
        // `reset` mixes the keys with the initialization constants.
        sip_hasher.reset();
        sip_hasher
    }

    /// Compression rounds run for every 8-byte message word: one SipRound.
    #[inline]
    fn c_rounds(state: &mut State) {
        compress!(state);
    }

    /// Finalization rounds run once in `finish`: three SipRounds.
    #[inline]
    fn d_rounds(state: &mut State) {
        compress!(state);
        compress!(state);
        compress!(state);
    }

    /// Re-initializes the state words from the keys and clears the byte counters.
    #[inline]
    fn reset(&mut self) {
        self.length = 0;
        self.state.v0 = self.k0 ^ 0x736f6d6570736575;
        self.state.v1 = self.k1 ^ 0x646f72616e646f6d;
        self.state.v2 = self.k0 ^ 0x6c7967656e657261;
        self.state.v3 = self.k1 ^ 0x7465646279746573;
        self.ntail = 0;
    }

    // Specialized write function that is only valid for buffers with len <= 8.
    // It's used to force inlining of write_u8 and write_usize, those would normally be inlined
    // except for composite types (that includes slices and str hashing because of delimiter).
    // Without this extra push the compiler is very reluctant to inline delimiter writes,
    // degrading performance substantially for the most common use cases.
    #[inline]
    fn short_write(&mut self, msg: &[u8]) {
        debug_assert!(msg.len() <= 8);
        let length = msg.len();
        self.length += length;
        // Number of bytes still missing to complete the buffered 8-byte word.
        let needed = 8 - self.ntail;
        let fill = cmp::min(length, needed);
        if fill == 8 {
            // SAFETY: `fill == 8` together with `length <= 8` means `msg` holds
            // exactly 8 bytes, so reading a u64 at offset 0 is in bounds.
            self.tail = unsafe { load_int_le!(msg, 0, u64) };
        } else {
            // SAFETY: `fill < 8` in this branch and `fill <= msg.len()`, so
            // `msg[0..fill]` is in bounds.
            self.tail |= unsafe { u8to64_le(msg, 0, fill) } << (8 * self.ntail);
            if length < needed {
                // Still less than a full word buffered; wait for more input.
                self.ntail += length;
                return;
            }
        }
        self.state.v3 ^= self.tail;
        Self::c_rounds(&mut self.state);
        self.state.v0 ^= self.tail;
        // Buffered tail is now flushed, process new input.
        self.ntail = length - needed;
        // SAFETY: the first `needed` bytes have been consumed and
        // `needed + self.ntail == msg.len()`, so `msg[needed..]` is in bounds;
        // `self.ntail < 8` because `length <= 8` and `needed >= 1`.
        self.tail = unsafe { u8to64_le(msg, needed, self.ntail) };
    }
}
/// `Hasher` implementation: input bytes are consumed as little-endian 8-byte
/// words; bytes that do not yet fill a word are buffered in `tail`/`ntail`.
impl Hasher for SipHasher13 {
// see short_write comment for explanation
#[inline]
fn write_usize(&mut self, i: usize) {
// SAFETY: `&i` points to a live local `usize`, which is exactly
// `size_of::<usize>()` initialized bytes, and `u8` has no alignment
// requirement. The bytes are viewed in their in-memory (native) order.
let bytes = unsafe {
slice::from_raw_parts(&i as *const usize as *const u8, mem::size_of::<usize>())
};
self.short_write(bytes);
}
// see short_write comment for explanation
#[inline]
fn write_u8(&mut self, i: u8) {
self.short_write(&[i]);
}
/// Feeds `msg` into the hash state. May be called repeatedly; the result is
/// computed over the concatenation of everything written so far.
#[inline]
fn write(&mut self, msg: &[u8]) {
let length = msg.len();
self.length += length;
// Number of leading bytes of `msg` used to complete a buffered partial word.
let mut needed = 0;
// A previous write left `ntail` bytes buffered: top the word up first.
if self.ntail != 0 {
needed = 8 - self.ntail;
// SAFETY: `cmp::min(length, needed)` is at most `length`, so the read stays
// inside `msg`, and it is below 8 because `needed = 8 - ntail` with `ntail >= 1`.
self.tail |= unsafe { u8to64_le(msg, 0, cmp::min(length, needed)) } << 8 * self.ntail;
if length < needed {
// Still not a full word; keep buffering.
self.ntail += length;
return;
} else {
self.state.v3 ^= self.tail;
Self::c_rounds(&mut self.state);
self.state.v0 ^= self.tail;
self.ntail = 0;
}
}
// Buffered tail is now flushed, process new input.
let len = length - needed;
// Bytes left over after the last whole 8-byte word (len % 8).
let left = len & 0x7;
let mut i = needed;
while i < len - left {
// SAFETY: `i` starts at `needed` and advances by 8, and the loop runs once per
// whole word contained in the `len` remaining bytes, so `i + 8 <= msg.len()`.
let mi = unsafe { load_int_le!(msg, i, u64) };
self.state.v3 ^= mi;
Self::c_rounds(&mut self.state);
self.state.v0 ^= mi;
i += 8;
}
// SAFETY: after the loop `i + left == msg.len()` and `left < 8`; when `left == 0`
// nothing is read and the tail becomes 0.
self.tail = unsafe { u8to64_le(msg, i, left) };
self.ntail = left;
}
/// Returns the hash of all bytes written so far. Works on a copy of the
/// state, so the hasher itself is left unchanged and can be written to again.
#[inline]
fn finish(&self) -> u64 {
let mut state = self.state;
// Final word: low byte of the total length in the top byte, buffered tail
// bytes (at most 7) in the low bytes.
let b: u64 = ((self.length as u64 & 0xff) << 56) | self.tail;
state.v3 ^= b;
Self::c_rounds(&mut state);
state.v0 ^= b;
// Finalization: flip the low byte of v2, run the d-rounds, fold the words.
state.v2 ^= 0xff;
Self::d_rounds(&mut state);
state.v0 ^ state.v1 ^ state.v2 ^ state.v3
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::hash::{Hash, Hasher};
use std::string::String;
#[test]
// Locks down the hash values produced when strings are fed in through `Hash`.
fn test_sip_hash13_string_hash() {
let mut sip_hash13 = SipHasher13::new();
let test_str1 = String::from("com.google.android.test");
test_str1.hash(&mut sip_hash13);
assert_eq!(17898838669067067585, sip_hash13.finish());
// The same hasher is reused, so this value covers both strings in sequence.
let test_str2 = String::from("adfadfadf adfafadadf 1231241241");
test_str2.hash(&mut sip_hash13);
assert_eq!(13543518987672889310, sip_hash13.finish());
}
#[test]
// Locks down the hash values produced through the raw `write` / `write_u8` calls.
fn test_sip_hash13_write() {
let mut sip_hash13 = SipHasher13::new();
let test_str1 = String::from("com.google.android.test");
// Raw bytes plus a 0xff byte give the same value as hashing the String above.
sip_hash13.write(test_str1.as_bytes());
sip_hash13.write_u8(0xff);
assert_eq!(17898838669067067585, sip_hash13.finish());
let mut sip_hash132 = SipHasher13::new();
let test_str1 = String::from("com.google.android.test");
// Without the trailing 0xff the value differs.
sip_hash132.write(test_str1.as_bytes());
assert_eq!(9685440969685209025, sip_hash132.finish());
// `finish` does not reset the hasher: a second write continues the same stream.
sip_hash132.write(test_str1.as_bytes());
assert_eq!(6719694176662736568, sip_hash132.finish());
// 7-byte string.
let mut sip_hash133 = SipHasher13::new();
let test_str2 = String::from("abcdefg");
test_str2.hash(&mut sip_hash133);
assert_eq!(2492161047327640297, sip_hash133.finish());
// 8-byte string: exactly one full word.
let mut sip_hash134 = SipHasher13::new();
let test_str3 = String::from("abcdefgh");
test_str3.hash(&mut sip_hash134);
assert_eq!(6689927370435554326, sip_hash134.finish());
}
#[test]
// Locks down the hash of a single byte written through the short-write path.
fn test_sip_hash13_write_short() {
let mut sip_hash13 = SipHasher13::new();
sip_hash13.write_u8(0x61);
assert_eq!(4644417185603328019, sip_hash13.finish());
}
}

View File

@@ -0,0 +1,119 @@
/*
* Copyright (C) 2024 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package android.aconfig.storage;
/**
 * SipHash 1-3 (one compression round per 8-byte word, three finalization rounds) with both
 * keys set to zero.
 *
 * <p>{@link #hash(byte[])} appends a single {@code 0xff} byte to the input before hashing so
 * that the result matches the Rust implementation, which writes the bytes followed by
 * {@code 0xff}.
 */
public class SipHasher13 {
    /** The four 64-bit words of SipHash internal state. */
    static class State {
        private long v0;
        private long v2;
        private long v1;
        private long v3;

        /** Initializes the state from the two key halves and the SipHash constants. */
        public State(long k0, long k1) {
            v0 = k0 ^ 0x736f6d6570736575L;
            v1 = k1 ^ 0x646f72616e646f6dL;
            v2 = k0 ^ 0x6c7967656e657261L;
            v3 = k1 ^ 0x7465646279746573L;
        }

        /** Absorbs one little-endian 8-byte message word {@code m}. */
        public void compress(long m) {
            v3 ^= m;
            cRounds();
            v0 ^= m;
        }

        /** Runs the finalization rounds and folds the state into the 64-bit hash. */
        public long finish() {
            v2 ^= 0xff;
            dRounds();
            return v0 ^ v1 ^ v2 ^ v3;
        }

        /** Compression rounds: a single SipRound. */
        private void cRounds() {
            sipRound();
        }

        /** Finalization rounds: three SipRounds. */
        private void dRounds() {
            for (int i = 0; i < 3; i++) {
                sipRound();
            }
        }

        /** One SipRound. Java {@code long} addition wraps, as the algorithm requires. */
        private void sipRound() {
            v0 += v1;
            v1 = Long.rotateLeft(v1, 13);
            v1 ^= v0;
            v0 = Long.rotateLeft(v0, 32);
            v2 += v3;
            v3 = Long.rotateLeft(v3, 16);
            v3 ^= v2;
            v0 += v3;
            v3 = Long.rotateLeft(v3, 21);
            v3 ^= v0;
            v2 += v1;
            v1 = Long.rotateLeft(v1, 17);
            v1 ^= v2;
            v2 = Long.rotateLeft(v2, 32);
        }
    }

    /**
     * Hashes {@code data} followed by one {@code 0xff} byte.
     *
     * @param data the bytes to hash; must not be null
     * @return the 64-bit SipHash 1-3 value
     * @throws NullPointerException if {@code data} is null
     */
    public static long hash(byte[] data) {
        State state = new State(0, 0);
        int len = data.length;
        // Number of bytes after the last whole 8-byte word (len % 8).
        int left = len & 0x7;
        int index = 0;

        // Absorb every whole 8-byte word.
        while (index < len - left) {
            long mi = loadLe(data, index, 8);
            index += 8;
            state.compress(mi);
        }

        // padding the end with 0xff to be consistent with rust
        long m = (0xffL << (left * 8)) | loadLe(data, index, left);
        if (left == 0x7) {
            // Seven leftover bytes plus the 0xff padding form a complete word: absorb it
            // and start the final word empty.
            state.compress(m);
            m = 0L;
        }

        // len adds 1 since padded 0xff; only its low byte goes into the top byte of the
        // final word.
        m |= (((len + 1) & 0xffL) << 56);
        state.compress(m);
        return state.finish();
    }

    /** Reads {@code size} bytes (0 to 8) starting at {@code offset} as a little-endian value. */
    private static long loadLe(byte[] data, int offset, int size) {
        long m = 0;
        for (int i = 0; i < size; i++) {
            m |= (data[i + offset] & 0xffL) << (i * 8);
        }
        return m;
    }
}

View File

@@ -0,0 +1,44 @@
/*
* Copyright (C) 2024 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package android.aconfig.storage.test;
import static org.junit.Assert.assertEquals;
import static java.nio.charset.StandardCharsets.UTF_8;
import android.aconfig.storage.SipHasher13;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Checks {@link SipHasher13#hash(byte[])} against fixed expected values. */
@RunWith(JUnit4.class)
public class SipHasher13Test {
    @Test
    public void testSipHash_hashString() throws Exception {
        // 23 bytes: two whole words plus seven trailing bytes.
        byte[] packageName = "com.google.android.test".getBytes(UTF_8);
        assertEquals(0xF86572EFF9C4A0C1L, SipHasher13.hash(packageName));

        // 7 bytes: no whole word, seven trailing bytes.
        byte[] sevenBytes = "abcdefg".getBytes(UTF_8);
        assertEquals(0x2295EF44BD078AE9L, SipHasher13.hash(sevenBytes));

        // 8 bytes: exactly one whole word, nothing trailing.
        byte[] eightBytes = "abcdefgh".getBytes(UTF_8);
        assertEquals(0x5CD7657FA7F96C16L, SipHasher13.hash(eightBytes));
    }
}