# Source code for geetools.bitreader

# coding=utf-8
"""Bit Reader module."""
import ee
import ee.data

import geetools.tools as tools


class BitReader(object):
    """Bit Reader.

    Initializes with parameter `options`, which must be a dictionary with
    the following format:

    - keys must be a str with the bits places, example: '0-1' means bit 0
      and bit 1. A single bit can be given as '2-2', '2' or the number 2.
    - values must be a dictionary with the bit value as the key and the
      category (str) as value. Categories must be unique.

    - Encode: given a category/categories return a list of possible values
    - Decode: given a value return a list of categories

    Example:
        MOD09 (http://modis-sr.ltdri.org/guide/MOD09_UserGuide_v1_3.pdf)
        (page 28, state1km, 16 bits):

        ```
        options = {
            '0-1': {0: 'clear', 1: 'cloud', 2: 'mix'},
            '2-2': {1: 'shadow'},
            '8-9': {1: 'small_cirrus', 2: 'average_cirrus', 3: 'high_cirrus'}
        }
        reader = BitReader(options, 16)
        print(reader.decode(204))
        ```
        >> ['clear', 'shadow']

        ```
        print(reader.match(204, 'cloud'))
        ```
        >> False
    """

    @staticmethod
    def getBin(bit, nbits=None, shift=0):
        """Return the binary representation of ``bit`` as a zero-padded str.

        Adapted from https://stackoverflow.com/questions/699866/python-int-to-binary.

        :param bit: the non-negative integer to represent
        :param nbits: width of the result. If None (or 0) the minimal width
            needed to represent ``bit`` is used.
        :param shift: number of places to left-shift ``bit`` before padding
        :return: the binary digits (no ``0b`` prefix), left-padded with zeros
            to ``nbits`` characters
        :raises ValueError: if ``bit`` is negative, does not fit in ``nbits``,
            or cannot be shifted ``shift`` places within ``nbits``
        """
        if bit < 0:
            # bin() of a negative number is "-0b...", which can't be sliced
            # into a valid bit string
            mje = "can't represent the negative number {} in bits"
            raise ValueError(mje.format(bit))
        pure = bin(bit)[2:]
        if not nbits:
            nbits = len(pure)
        admited_shift = nbits - len(pure)
        if admited_shift < 0:
            mje = (
                "the number of bits must be more than the bits"
                " representation of the number. {} ({}) can't be"
                " represented in {} bits"
            )
            raise ValueError(mje.format(pure, bit, nbits))
        if shift > admited_shift:
            mje = "can't shift {} places for bit {} ({})"
            raise ValueError(mje.format(shift, pure, bit))
        shifted = bin(int(pure, 2) << shift)[2:] if shift else pure
        return shifted.zfill(nbits)

    @staticmethod
    def decodeKey(key):
        """Decode an option's key into the bit positions it covers.

        :param key: a str ``"ini-end"`` (both ends inclusive), a str ``"bit"``,
            or a number naming a single bit
        :return: a ``range`` with every covered bit position, in ascending order
        :raises ValueError: if a str key is malformed or its last bit is lower
            than its first bit
        :raises TypeError: if the key is neither a str nor a number
        """
        if isinstance(key, (str,)):
            bits = key.split("-")
            mje = (
                'keys must be with the following format "bit-bit", '
                'example "0-1" (found {})'
            )
            if len(bits) > 2:
                raise ValueError(mje.format(key))
            try:
                ini = int(bits[0])
                end = ini if len(bits) == 1 else int(bits[1])
            except ValueError:
                raise ValueError(mje.format(key))
            if end < ini:
                # an empty range would later break the bit bookkeeping
                msg = "the last bit must not be lower than the first (found {})"
                raise ValueError(msg.format(key))
            return range(ini, end + 1)
        if isinstance(key, (int, float)):
            value = int(key)
            # a numeric key names exactly one bit, the same as "n-n"
            return range(value, value + 1)
        msg = "keys must be str or numbers (found {!r})"
        raise TypeError(msg.format(key))

    @staticmethod
    def _readBits(value, lshift, length):
        """Return the ``length`` bits of ``value`` that start at bit ``lshift``."""
        move = lshift + length
        # clear every bit above the range, then bring the range down to bit 0
        rest = value >> move << move
        return (value - rest) >> lshift

    def __init__(self, options, bit_length=None):
        """Build the reader from an ``options`` mapping.

        :param options: mapping of bit keys (see ``decodeKey``) to dicts of
            ``{bit value: category}``
        :param bit_length: total number of bits of the values to read. If
            None (or 0), the number of bits needed to hold the highest bit
            position referenced in ``options`` is used.
        :raises ValueError: if a category repeats, bit ranges overlap, or a
            key is malformed
        """
        self.options = options

        # Check if categories repeat and create property all_categories,
        # keeping the order in which they appear in options.
        # TODO: reformat categories if find spaces or uppercases
        all_cat = []
        for val in self.options.values():
            for cat in val.values():
                if cat in all_cat:
                    msg = 'Classes must be unique, found "{}" twice'
                    raise ValueError(msg.format(cat))
                all_cat.append(cat)
        self.all_categories = all_cat

        # every bit position may belong to at most one key
        all_values = [x for key in options for x in self.decodeKey(key)]
        if len(set(all_values)) != len(all_values):
            mje = (
                "bits must not overlap. Example: {'0-1':.., "
                "'2-3':..} and NOT {'0-1':.., '1-3':..}"
            )
            raise ValueError(mje)

        if not bit_length:
            # bit positions are 0-indexed, so position p needs p + 1 bits
            bit_length = max(all_values) + 1
        self.bit_length = bit_length
        self.max = 2**self.bit_length

        info = {}
        for key, val in options.items():
            bits_list = self.decodeKey(key)
            for shifted, cat in val.items():
                info[cat] = {
                    "bit_length": len(bits_list),
                    "lshift": bits_list[0],
                    "shifted": shifted,
                }
        self.info = info

    def encode(self, cat):
        """Given a category, return the encoded value (only).

        :param cat: the category to encode
        :return: the category's bit value shifted into its bit position; all
            other bits are 0
        """
        info = self.info[cat]
        return info["shifted"] << info["lshift"]

    def encodeBand(self, category, mask, name=None):
        """Make an image in which all pixels have the value for the given category.

        :param category: the category to encode
        :type category: str
        :param mask: the mask that indicates which pixels encode
        :type mask: ee.Image
        :param name: name of the resulting band. If None it'll be the name of
            the first band of 'mask'
        :type name: str
        :return: A one band image built with ``tools.image.empty`` from the
            encoded value, masked by ``mask``
        :rtype: ee.Image
        """
        encoded = self.encode(category)
        if not name:
            name = mask.bandNames().get(0)
        image = tools.image.empty(encoded, [name])
        return image.updateMask(mask)

    def encodeAnd(self, *args):
        """Encode a combination of categories that must ALL match.

        :param args: one or more categories
        :return: list of possible values that match every given category,
            ordered as ``encodeOne`` yields them for the last category
        """
        result = self.encodeOne(args[0])
        for cat in args[1:]:
            # set lookup keeps this linear instead of quadratic
            keep = set(result)
            result = [val for val in self.encodeOne(cat) if val in keep]
        return result

    def encodeOr(self, *args):
        """Encode a combination of categories of which ANY may match.

        :param args: one or more categories
        :return: list of possible values that match at least one category,
            without duplicates, in first-seen order
        """
        result = self.encodeOne(args[0])
        seen = set(result)
        for cat in args[1:]:
            for value in self.encodeOne(cat):
                if value not in seen:
                    seen.add(value)
                    result.append(value)
        return result

    def encodeNot(self, *args):
        """Given a set of categories, return the values that match none of them.

        :param args: one or more categories
        :return: ascending list of values in ``range(self.max)`` that match
            none of the given categories
        """
        match = set(self.encodeOr(*args))
        return [bit for bit in range(self.max) if bit not in match]

    def encodeOne(self, cat):
        """Given a category, return a list of values that match it.

        :param cat: the category
        :return: ascending list of values in ``range(self.max)`` whose bits
            for this category hold the category's value
        """
        info = self.info[cat]
        lshift = info["lshift"]
        length = info["bit_length"]
        decoded = info["shifted"]
        return [
            bit
            for bit in range(self.max)
            if self._readBits(bit, lshift, length) == decoded
        ]

    def decode(self, value):
        """Given a value, return a list with all the categories it holds.

        :param value: the integer value to decode
        :return: matching categories, in the order of ``self.all_categories``
        """
        result = []
        for cat in self.all_categories:
            data = self.info[cat]
            bits = self._readBits(value, data["lshift"], data["bit_length"])
            if bits == data["shifted"]:
                result.append(cat)
        return result

    def decodeImage(self, image, qa_band):
        """Get an Image with one band per category in the Bit Reader.

        :param image: the image that holds the bit information
        :type image: ee.Image
        :param qa_band: name of the band that holds the bit information
        :type qa_band: str
        :return: an image with one band per category (named after it), 1
            where the category is present and 0 where it is not. Only the
            category bands are kept; the bands of ``image`` are dropped.
        :rtype: ee.Image
        """
        options = ee.Dictionary(self.info)
        categories = ee.List(self.all_categories)

        def eachcat(cat, ini):
            ini = ee.Image(ini)
            qa = ini.select(qa_band)
            # get data for category
            data = ee.Dictionary(options.get(cat))
            lshift = ee.Number(data.get("lshift"))
            length = ee.Number(data.get("bit_length"))
            decoded = ee.Number(data.get("shifted"))
            # move = places to move bits right and left back
            move = lshift.add(length)
            # move bits right and left
            rest = qa.rightShift(move).leftShift(move)
            # subtract the rest
            norest = qa.subtract(rest)
            # right shift to compare with decoded data
            to_compare = norest.rightShift(lshift)
            # compare if is equal, return 0 if not equal, 1 if equal
            mask = to_compare.eq(decoded)
            # rename to the name of the category
            qa_mask = mask.select([0], [cat])
            return ini.addBands(qa_mask)

        return ee.Image(categories.iterate(eachcat, image)).select(categories)

    def match(self, value, category):
        """Return True if the given value includes the given category, else False.

        :param value: the integer value to check
        :param category: the category to look for
        :return: whether ``category`` is among ``decode(value)``
        """
        return category in self.decode(value)