refactor BlockDifference into common

Move BlockDifference into common and make its script generation code
more complete, so that it can be use by releasetools.py to do diffs on
baseband images.

Bug: 16984795
Change-Id: Iba9afc1c7755458ce47468b5170672612b2cb4b3
This commit is contained in:
Doug Zongker
2014-08-26 10:40:28 -07:00
parent 4d0bfb4f40
commit ab7ca1d286
3 changed files with 130 additions and 60 deletions

View File

@@ -14,6 +14,8 @@ import tempfile
from rangelib import * from rangelib import *
__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]
def compute_patch(src, tgt, imgdiff=False): def compute_patch(src, tgt, imgdiff=False):
srcfd, srcfile = tempfile.mkstemp(prefix="src-") srcfd, srcfile = tempfile.mkstemp(prefix="src-")
tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-") tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
@@ -60,6 +62,59 @@ class EmptyImage(object):
file_map = {} file_map = {}
def ReadRangeSet(self, ranges): def ReadRangeSet(self, ranges):
return () return ()
def TotalSha1(self):
return sha1().hexdigest()
class DataImage(object):
"""An image wrapped around a single string of data."""
def __init__(self, data, trim=False, pad=False):
self.data = data
self.blocksize = 4096
assert not (trim and pad)
partial = len(self.data) % self.blocksize
if partial > 0:
if trim:
self.data = self.data[:-partial]
elif pad:
self.data += '\0' * (self.blocksize - partial)
else:
raise ValueError(("data for DataImage must be multiple of %d bytes "
"unless trim or pad is specified") %
(self.blocksize,))
assert len(self.data) % self.blocksize == 0
self.total_blocks = len(self.data) / self.blocksize
self.care_map = RangeSet(data=(0, self.total_blocks))
zero_blocks = []
nonzero_blocks = []
reference = '\0' * self.blocksize
for i in range(self.total_blocks):
d = self.data[i*self.blocksize : (i+1)*self.blocksize]
if d == reference:
zero_blocks.append(i)
zero_blocks.append(i+1)
else:
nonzero_blocks.append(i)
nonzero_blocks.append(i+1)
self.file_map = {"__ZERO": RangeSet(zero_blocks),
"__NONZERO": RangeSet(nonzero_blocks)}
def ReadRangeSet(self, ranges):
return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]
def TotalSha1(self):
if not hasattr(self, "sha1"):
self.sha1 = sha1(self.data).hexdigest()
return self.sha1
class Transfer(object): class Transfer(object):
def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id): def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
@@ -104,6 +159,10 @@ class Transfer(object):
# Implementations are free to break up the data into list/tuple # Implementations are free to break up the data into list/tuple
# elements in any way that is convenient. # elements in any way that is convenient.
# #
# TotalSha1(): a function that returns (as a hex string) the SHA-1
# hash of all the data in the image (ie, all the blocks in the
# care_map)
#
# When creating a BlockImageDiff, the src image may be None, in which # When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the # case the list of transfers produced will never read from the
# original image. # original image.
@@ -478,7 +537,12 @@ class BlockImageDiff(object):
# If the blocks written by A are read by B, then B needs to go before A. # If the blocks written by A are read by B, then B needs to go before A.
i = a.tgt_ranges.intersect(b.src_ranges) i = a.tgt_ranges.intersect(b.src_ranges)
if i: if i:
size = i.size() if b.src_name == "__ZERO":
# the cost of removing source blocks for the __ZERO domain
# is (nearly) zero.
size = 0
else:
size = i.size()
b.goes_before[a] = size b.goes_before[a] = size
a.goes_after[b] = size a.goes_after[b] = size
@@ -491,7 +555,8 @@ class BlockImageDiff(object):
# in any file and that are filled with zeros. We have a # in any file and that are filled with zeros. We have a
# special transfer style for zero blocks. # special transfer style for zero blocks.
src_ranges = self.src.file_map.get("__ZERO", empty) src_ranges = self.src.file_map.get("__ZERO", empty)
Transfer(tgt_fn, None, tgt_ranges, src_ranges, "zero", self.transfers) Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
"zero", self.transfers)
continue continue
elif tgt_fn in self.src.file_map: elif tgt_fn in self.src.file_map:

View File

@@ -29,6 +29,8 @@ import threading
import time import time
import zipfile import zipfile
import blockimgdiff
try: try:
from hashlib import sha1 as sha1 from hashlib import sha1 as sha1
except ImportError: except ImportError:
@@ -1010,6 +1012,60 @@ def ComputeDifferences(diffs):
threads.pop().join() threads.pop().join()
class BlockDifference:
def __init__(self, partition, tgt, src=None):
self.tgt = tgt
self.src = src
self.partition = partition
b = blockimgdiff.BlockImageDiff(tgt, src, threads=OPTIONS.worker_threads)
tmpdir = tempfile.mkdtemp()
OPTIONS.tempfiles.append(tmpdir)
self.path = os.path.join(tmpdir, partition)
b.Compute(self.path)
_, self.device = GetTypeAndDevice("/" + partition, OPTIONS.info_dict)
def WriteScript(self, script, output_zip, progress=None):
if not self.src:
# write the output unconditionally
if progress: script.ShowProgress(progress, 0)
self._WriteUpdate(script, output_zip)
else:
script.AppendExtra('if range_sha1("%s", "%s") == "%s" then' %
(self.device, self.src.care_map.to_string_raw(),
self.src.TotalSha1()))
script.Print("Patching %s image..." % (self.partition,))
if progress: script.ShowProgress(progress, 0)
self._WriteUpdate(script, output_zip)
script.AppendExtra(('else\n'
' (range_sha1("%s", "%s") == "%s") ||\n'
' abort("%s partition has unexpected contents");\n'
'endif;') %
(self.device, self.tgt.care_map.to_string_raw(),
self.tgt.TotalSha1(), self.partition))
def _WriteUpdate(self, script, output_zip):
partition = self.partition
with open(self.path + ".transfer.list", "rb") as f:
ZipWriteStr(output_zip, partition + ".transfer.list", f.read())
with open(self.path + ".new.dat", "rb") as f:
ZipWriteStr(output_zip, partition + ".new.dat", f.read())
with open(self.path + ".patch.dat", "rb") as f:
ZipWriteStr(output_zip, partition + ".patch.dat", f.read(),
compression=zipfile.ZIP_STORED)
call = (('block_image_update("%s", '
'package_extract_file("%s.transfer.list"), '
'"%s.new.dat", "%s.patch.dat");\n') %
(self.device, partition, partition, partition))
script.AppendExtra(script._WordWrap(call))
DataImage = blockimgdiff.DataImage
# map recovery.fstab's fs_types to mount/format "partition types" # map recovery.fstab's fs_types to mount/format "partition types"
PARTITION_TYPES = { "yaffs2": "MTD", "mtd": "MTD", PARTITION_TYPES = { "yaffs2": "MTD", "mtd": "MTD",
"ext4": "EMMC", "emmc": "EMMC", "ext4": "EMMC", "emmc": "EMMC",

View File

@@ -455,35 +455,6 @@ def GetImage(which, tmpdir, info_dict):
return sparse_img.SparseImage(path, mappath) return sparse_img.SparseImage(path, mappath)
class BlockDifference:
def __init__(self, partition, tgt, src=None):
self.partition = partition
b = blockimgdiff.BlockImageDiff(tgt, src, threads=OPTIONS.worker_threads)
tmpdir = tempfile.mkdtemp()
OPTIONS.tempfiles.append(tmpdir)
self.path = os.path.join(tmpdir, partition)
b.Compute(self.path)
_, self.device = common.GetTypeAndDevice("/" + partition, OPTIONS.info_dict)
def WriteScript(self, script, output_zip):
partition = self.partition
with open(self.path + ".transfer.list", "rb") as f:
common.ZipWriteStr(output_zip, partition + ".transfer.list", f.read())
with open(self.path + ".new.dat", "rb") as f:
common.ZipWriteStr(output_zip, partition + ".new.dat", f.read())
with open(self.path + ".patch.dat", "rb") as f:
common.ZipWriteStr(output_zip, partition + ".patch.dat", f.read(),
compression=zipfile.ZIP_STORED)
call = (('block_image_update("%s", '
'package_extract_file("%s.transfer.list"), '
'"%s.new.dat", "%s.patch.dat");\n') %
(self.device, partition, partition, partition))
script.AppendExtra(script._WordWrap(call))
def WriteFullOTAPackage(input_zip, output_zip): def WriteFullOTAPackage(input_zip, output_zip):
# TODO: how to determine this? We don't know what version it will # TODO: how to determine this? We don't know what version it will
# be installed on top of. For now, we expect the API just won't # be installed on top of. For now, we expect the API just won't
@@ -586,7 +557,7 @@ else if get_stage("%(bcb_dev)s", "stage") == "3/3" then
# writes incrementals to do it. # writes incrementals to do it.
system_tgt = GetImage("system", OPTIONS.input_tmp, OPTIONS.info_dict) system_tgt = GetImage("system", OPTIONS.input_tmp, OPTIONS.info_dict)
system_tgt.ResetFileMap() system_tgt.ResetFileMap()
system_diff = BlockDifference("system", system_tgt, src=None) system_diff = common.BlockDifference("system", system_tgt, src=None)
system_diff.WriteScript(script, output_zip) system_diff.WriteScript(script, output_zip)
else: else:
script.FormatPartition("/system") script.FormatPartition("/system")
@@ -619,7 +590,7 @@ else if get_stage("%(bcb_dev)s", "stage") == "3/3" then
if block_based: if block_based:
vendor_tgt = GetImage("vendor", OPTIONS.input_tmp, OPTIONS.info_dict) vendor_tgt = GetImage("vendor", OPTIONS.input_tmp, OPTIONS.info_dict)
vendor_tgt.ResetFileMap() vendor_tgt.ResetFileMap()
vendor_diff = BlockDifference("vendor", vendor_tgt) vendor_diff = common.BlockDifference("vendor", vendor_tgt)
vendor_diff.WriteScript(script, output_zip) vendor_diff.WriteScript(script, output_zip)
else: else:
script.FormatPartition("/vendor") script.FormatPartition("/vendor")
@@ -760,14 +731,14 @@ def WriteBlockIncrementalOTAPackage(target_zip, source_zip, output_zip):
system_src = GetImage("system", OPTIONS.source_tmp, OPTIONS.source_info_dict) system_src = GetImage("system", OPTIONS.source_tmp, OPTIONS.source_info_dict)
system_tgt = GetImage("system", OPTIONS.target_tmp, OPTIONS.target_info_dict) system_tgt = GetImage("system", OPTIONS.target_tmp, OPTIONS.target_info_dict)
system_diff = BlockDifference("system", system_tgt, system_src) system_diff = common.BlockDifference("system", system_tgt, system_src)
if HasVendorPartition(target_zip): if HasVendorPartition(target_zip):
if not HasVendorPartition(source_zip): if not HasVendorPartition(source_zip):
raise RuntimeError("can't generate incremental that adds /vendor") raise RuntimeError("can't generate incremental that adds /vendor")
vendor_src = GetImage("vendor", OPTIONS.source_tmp, OPTIONS.source_info_dict) vendor_src = GetImage("vendor", OPTIONS.source_tmp, OPTIONS.source_info_dict)
vendor_tgt = GetImage("vendor", OPTIONS.target_tmp, OPTIONS.target_info_dict) vendor_tgt = GetImage("vendor", OPTIONS.target_tmp, OPTIONS.target_info_dict)
vendor_diff = BlockDifference("vendor", vendor_tgt, vendor_src) vendor_diff = common.BlockDifference("vendor", vendor_tgt, vendor_src)
else: else:
vendor_diff = None vendor_diff = None
@@ -867,32 +838,10 @@ else
device_specific.IncrementalOTA_InstallBegin() device_specific.IncrementalOTA_InstallBegin()
script.AppendExtra('if range_sha1("%s", "%s") == "%s" then' % system_diff.WriteScript(script, output_zip,
(system_diff.device, system_src.care_map.to_string_raw(), progress=0.8 if vendor_diff else 0.9)
system_src.TotalSha1()))
script.Print("Patching system image...")
script.ShowProgress(0.8 if vendor_diff else 0.9, 0)
system_diff.WriteScript(script, output_zip)
script.AppendExtra(('else\n'
' (range_sha1("%s", "%s") == "%s") ||\n'
' abort("system partition has unexpected contents");\n'
'endif;') %
(system_diff.device, system_tgt.care_map.to_string_raw(),
system_tgt.TotalSha1()))
if vendor_diff: if vendor_diff:
script.AppendExtra('if range_sha1("%s", "%s") == "%s" then' % vendor_diff.WriteScript(script, output_zip, progress=0.1)
(vendor_diff.device, vendor_src.care_map.to_string_raw(),
vendor_src.TotalSha1()))
script.Print("Patching vendor image...")
script.ShowProgress(0.1, 0)
vendor_diff.WriteScript(script, output_zip)
script.AppendExtra(('else\n'
' (range_sha1("%s", "%s") == "%s") ||\n'
' abort("vendor partition has unexpected contents");\n'
'endif;') %
(vendor_diff.device, vendor_tgt.care_map.to_string_raw(),
vendor_tgt.TotalSha1()))
if OPTIONS.two_step: if OPTIONS.two_step:
common.ZipWriteStr(output_zip, "boot.img", target_boot.data) common.ZipWriteStr(output_zip, "boot.img", target_boot.data)