Quellcodebibliothek Statistik Leitseite products/Sources/formale Sprachen/C/Firefox/third_party/jpeg-xl/lib/jxl/   (Browser von der Mozilla Stiftung Version 136.0.1©)  Datei vom 10.2.2025 mit Größe 68 kB image not shown  

Quelle  enc_ans.cc   Sprache: C

 
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include "lib/jxl/enc_ans.h"

#include <jxl/memory_manager.h>
#include <jxl/types.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "lib/jxl/ans_common.h"
#include "lib/jxl/base/bits.h"
#include "lib/jxl/base/fast_math-inl.h"
#include "lib/jxl/base/status.h"
#include "lib/jxl/dec_ans.h"
#include "lib/jxl/enc_ans_params.h"
#include "lib/jxl/enc_aux_out.h"
#include "lib/jxl/enc_cluster.h"
#include "lib/jxl/enc_context_map.h"
#include "lib/jxl/enc_fields.h"
#include "lib/jxl/enc_huffman.h"
#include "lib/jxl/enc_params.h"
#include "lib/jxl/fields.h"

namespace jxl {

namespace {

#if (!JXL_IS_DEBUG_BUILD)
constexpr
#endif
    bool ans_fuzzer_friendly_ = false;

const int kMaxNumSymbolsForSmallCode = 4;

void ANSBuildInfoTable(const ANSHistBin* counts, const AliasTable::Entry* table,
                       size_t alphabet_size, size_t log_alpha_size,
                       ANSEncSymbolInfo* info) {
  size_t log_entry_size = ANS_LOG_TAB_SIZE - log_alpha_size;
  size_t entry_size_minus_1 = (1 << log_entry_size) - 1;
  // create valid alias table for empty streams.
  for (size_t s = 0; s < std::max<size_t>(1, alphabet_size); ++s) {
    const ANSHistBin freq = s == alphabet_size ? ANS_TAB_SIZE : counts[s];
    info[s].freq_ = static_cast<uint16_t>(freq);
#ifdef USE_MULT_BY_RECIPROCAL
    if (freq != 0) {
      info[s].ifreq_ =
          ((1ull << RECIPROCAL_PRECISION) + info[s].freq_ - 1) / info[s].freq_;
    } else {
      info[s].ifreq_ = 1;  // shouldn't matter (symbol shouldn't occur), but...
    }
#endif
    info[s].reverse_map_.resize(freq);
  }
  for (int i = 0; i < ANS_TAB_SIZE; i++) {
    AliasTable::Symbol s =
        AliasTable::Lookup(table, i, log_entry_size, entry_size_minus_1);
    info[s.value].reverse_map_[s.offset] = i;
  }
}

float EstimateDataBits(const ANSHistBin* histogram, const ANSHistBin* counts,
                       size_t len) {
  float sum = 0.0f;
  int total_histogram = 0;
  int total_counts = 0;
  for (size_t i = 0; i < len; ++i) {
    total_histogram += histogram[i];
    total_counts += counts[i];
    if (histogram[i] > 0) {
      JXL_DASSERT(counts[i] > 0);
      // += histogram[i] * -log(counts[i]/total_counts)
      sum += histogram[i] *
             std::max(0.0f, ANS_LOG_TAB_SIZE - FastLog2f(counts[i]));
    }
  }
  if (total_histogram > 0) {
    // Used only in assert.
    (void)total_counts;
    JXL_DASSERT(total_counts == ANS_TAB_SIZE);
  }
  return sum;
}

float EstimateDataBitsFlat(const ANSHistBin* histogram, size_t len) {
  const float flat_bits = std::max(FastLog2f(len), 0.0f);
  float total_histogram = 0;
  for (size_t i = 0; i < len; ++i) {
    total_histogram += histogram[i];
  }
  return total_histogram * flat_bits;
}

// Static Huffman code for encoding logcounts. The last symbol is used as RLE
// sequence.
const uint8_t kLogCountBitLengths[ANS_LOG_TAB_SIZE + 2] = {
    5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7,
};
const uint8_t kLogCountSymbols[ANS_LOG_TAB_SIZE + 2] = {
    17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65,
};

// Returns the difference between largest count that can be represented and is
// smaller than "count" and smallest representable count larger than "count".
int SmallestIncrement(uint32_t count, uint32_t shift) {
  int bits = count == 0 ? -1 : FloorLog2Nonzero(count);
  int drop_bits = bits - GetPopulationCountPrecision(bits, shift);
  return drop_bits < 0 ? 1 : (1 << drop_bits);
}

template <bool minimize_error_of_sum>
bool RebalanceHistogram(const float* targets, int max_symbol, int table_size,
                        uint32_t shift, int* omit_pos, ANSHistBin* counts) {
  int sum = 0;
  float sum_nonrounded = 0.0;
  int remainder_pos = 0;  // if all of them are handled in first loop
  int remainder_log = -1;
  for (int n = 0; n < max_symbol; ++n) {
    if (targets[n] > 0 && targets[n] < 1.0f) {
      counts[n] = 1;
      sum_nonrounded += targets[n];
      sum += counts[n];
    }
  }
  const float discount_ratio =
      (table_size - sum) / (table_size - sum_nonrounded);
  JXL_ENSURE(discount_ratio > 0);
  JXL_ENSURE(discount_ratio <= 1.0f);
  // Invariant for minimize_error_of_sum == true:
  // abs(sum - sum_nonrounded)
  //   <= SmallestIncrement(max(targets[])) + max_symbol
  for (int n = 0; n < max_symbol; ++n) {
    if (targets[n] >= 1.0f) {
      sum_nonrounded += targets[n];
      counts[n] =
          static_cast<ANSHistBin>(targets[n] * discount_ratio);  // truncate
      if (counts[n] == 0) counts[n] = 1;
      if (counts[n] == table_size) counts[n] = table_size - 1;
      // Round the count to the closest nonzero multiple of SmallestIncrement
      // (when minimize_error_of_sum is false) or one of two closest so as to
      // keep the sum as close as possible to sum_nonrounded.
      int inc = SmallestIncrement(counts[n], shift);
      counts[n] -= counts[n] & (inc - 1);
      // TODO(robryk): Should we rescale targets[n]?
      const int target = minimize_error_of_sum
                             ? (static_cast<int>(sum_nonrounded) - sum)
                             : static_cast<int>(targets[n]);
      if (counts[n] == 0 ||
          (target >= counts[n] + inc / 2 && counts[n] + inc < table_size)) {
        counts[n] += inc;
      }
      sum += counts[n];
      const int count_log = FloorLog2Nonzero(static_cast<uint32_t>(counts[n]));
      if (count_log > remainder_log) {
        remainder_pos = n;
        remainder_log = count_log;
      }
    }
  }
  JXL_ENSURE(remainder_pos != -1);
  // NOTE: This is the only place where counts could go negative. We could
  // detect that, return false and make ANSHistBin uint32_t.
  counts[remainder_pos] -= sum - table_size;
  *omit_pos = remainder_pos;
  return counts[remainder_pos] > 0;
}

Status NormalizeCounts(ANSHistBin* counts, int* omit_pos, const int length,
                       const int precision_bits, uint32_t shift,
                       int* num_symbols, int* symbols) {
  const int32_t table_size = 1 << precision_bits;  // target sum / table size
  uint64_t total = 0;
  int max_symbol = 0;
  int symbol_count = 0;
  for (int n = 0; n < length; ++n) {
    total += counts[n];
    if (counts[n] > 0) {
      if (symbol_count < kMaxNumSymbolsForSmallCode) {
        symbols[symbol_count] = n;
      }
      ++symbol_count;
      max_symbol = n + 1;
    }
  }
  *num_symbols = symbol_count;
  if (symbol_count == 0) {
    return true;
  }
  if (symbol_count == 1) {
    counts[symbols[0]] = table_size;
    return true;
  }
  if (symbol_count > table_size)
    return JXL_FAILURE("Too many entries in an ANS histogram");

  const float norm = 1.f * table_size / total;
  std::vector<float> targets(max_symbol);
  for (size_t n = 0; n < targets.size(); ++n) {
    targets[n] = norm * counts[n];
  }
  if (!RebalanceHistogram<false>(targets.data(), max_symbol, table_size, shift,
                                 omit_pos, counts)) {
    // Use an alternative rebalancing mechanism if the one above failed
    // to create a histogram that is positive wherever the original one was.
    if (!RebalanceHistogram<true>(targets.data(), max_symbol, table_size, shift,
                                  omit_pos, counts)) {
      return JXL_FAILURE("Logic error: couldn't rebalance a histogram");
    }
  }
  return true;
}

struct SizeWriter {
  size_t size = 0;
  void Write(size_t num, size_t bits) { size += num; }
};

template <typename Writer>
void StoreVarLenUint8(size_t n, Writer* writer) {
  JXL_DASSERT(n <= 255);
  if (n == 0) {
    writer->Write(1, 0);
  } else {
    writer->Write(1, 1);
    size_t nbits = FloorLog2Nonzero(n);
    writer->Write(3, nbits);
    writer->Write(nbits, n - (1ULL << nbits));
  }
}

template <typename Writer>
void StoreVarLenUint16(size_t n, Writer* writer) {
  JXL_DASSERT(n <= 65535);
  if (n == 0) {
    writer->Write(1, 0);
  } else {
    writer->Write(1, 1);
    size_t nbits = FloorLog2Nonzero(n);
    writer->Write(4, nbits);
    writer->Write(nbits, n - (1ULL << nbits));
  }
}

template <typename Writer>
bool EncodeCounts(const ANSHistBin* counts, const int alphabet_size,
                  const int omit_pos, const int num_symbols, uint32_t shift,
                  const int* symbols, Writer* writer) {
  bool ok = true;
  if (num_symbols <= 2) {
    // Small tree marker to encode 1-2 symbols.
    writer->Write(1, 1);
    if (num_symbols == 0) {
      writer->Write(1, 0);
      StoreVarLenUint8(0, writer);
    } else {
      writer->Write(1, num_symbols - 1);
      for (int i = 0; i < num_symbols; ++i) {
        StoreVarLenUint8(symbols[i], writer);
      }
    }
    if (num_symbols == 2) {
      writer->Write(ANS_LOG_TAB_SIZE, counts[symbols[0]]);
    }
  } else {
    // Mark non-small tree.
    writer->Write(1, 0);
    // Mark non-flat histogram.
    writer->Write(1, 0);

    // Precompute sequences for RLE encoding. Contains the number of identical
    // values starting at a given index. Only contains the value at the first
    // element of the series.
    std::vector<uint32_t> same(alphabet_size, 0);
    int last = 0;
    for (int i = 1; i < alphabet_size; i++) {
      // Store the sequence length once different symbol reached, or we're at
      // the end, or the length is longer than we can encode, or we are at
      // the omit_pos. We don't support including the omit_pos in an RLE
      // sequence because this value may use a different amount of log2 bits
      // than standard, it is too complex to handle in the decoder.
      if (counts[i] != counts[last] || i + 1 == alphabet_size ||
          (i - last) >= 255 || i == omit_pos || i == omit_pos + 1) {
        same[last] = (i - last);
        last = i + 1;
      }
    }

    int length = 0;
    std::vector<int> logcounts(alphabet_size);
    int omit_log = 0;
    for (int i = 0; i < alphabet_size; ++i) {
      JXL_ENSURE(counts[i] <= ANS_TAB_SIZE);
      JXL_ENSURE(counts[i] >= 0);
      if (i == omit_pos) {
        length = i + 1;
      } else if (counts[i] > 0) {
        logcounts[i] = FloorLog2Nonzero(static_cast<uint32_t>(counts[i])) + 1;
        length = i + 1;
        if (i < omit_pos) {
          omit_log = std::max(omit_log, logcounts[i] + 1);
        } else {
          omit_log = std::max(omit_log, logcounts[i]);
        }
      }
    }
    logcounts[omit_pos] = omit_log;

    // Elias gamma-like code for shift. Only difference is that if the number
    // of bits to be encoded is equal to FloorLog2(ANS_LOG_TAB_SIZE+1), we skip
    // the terminating 0 in unary coding.
    int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
    int log = FloorLog2Nonzero(shift + 1);
    writer->Write(log, (1 << log) - 1);
    if (log != upper_bound_log) writer->Write(1, 0);
    writer->Write(log, ((1 << log) - 1) & (shift + 1));

    // Since num_symbols >= 3, we know that length >= 3, therefore we encode
    // length - 3.
    if (length - 3 > 255) {
      // Pretend that everything is OK, but complain about correctness later.
      StoreVarLenUint8(255, writer);
      ok = false;
    } else {
      StoreVarLenUint8(length - 3, writer);
    }

    // The logcount values are encoded with a static Huffman code.
    static const size_t kMinReps = 4;
    size_t rep = ANS_LOG_TAB_SIZE + 1;
    for (int i = 0; i < length; ++i) {
      if (i > 0 && same[i - 1] > kMinReps) {
        // Encode the RLE symbol and skip the repeated ones.
        writer->Write(kLogCountBitLengths[rep], kLogCountSymbols[rep]);
        StoreVarLenUint8(same[i - 1] - kMinReps - 1, writer);
        i += same[i - 1] - 2;
        continue;
      }
      writer->Write(kLogCountBitLengths[logcounts[i]],
                    kLogCountSymbols[logcounts[i]]);
    }
    for (int i = 0; i < length; ++i) {
      if (i > 0 && same[i - 1] > kMinReps) {
        // Skip symbols encoded by RLE.
        i += same[i - 1] - 2;
        continue;
      }
      if (logcounts[i] > 1 && i != omit_pos) {
        int bitcount = GetPopulationCountPrecision(logcounts[i] - 1, shift);
        int drop_bits = logcounts[i] - 1 - bitcount;
        JXL_ENSURE((counts[i] & ((1 << drop_bits) - 1)) == 0);
        writer->Write(bitcount, (counts[i] >> drop_bits) - (1 << bitcount));
      }
    }
  }
  return ok;
}

void EncodeFlatHistogram(const int alphabet_size, BitWriter* writer) {
  // Mark non-small tree.
  writer->Write(1, 0);
  // Mark uniform histogram.
  writer->Write(1, 1);
  JXL_DASSERT(alphabet_size > 0);
  // Encode alphabet size.
  StoreVarLenUint8(alphabet_size - 1, writer);
}

StatusOr<float> ComputeHistoAndDataCost(const ANSHistBin* histogram,
                                        size_t alphabet_size, uint32_t method) {
  if (method == 0) {  // Flat code
    return ANS_LOG_TAB_SIZE + 2 +
           EstimateDataBitsFlat(histogram, alphabet_size);
  }
  // Non-flat: shift = method-1.
  uint32_t shift = method - 1;
  std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
  int omit_pos = 0;
  int num_symbols;
  int symbols[kMaxNumSymbolsForSmallCode] = {};
  JXL_RETURN_IF_ERROR(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
                                      ANS_LOG_TAB_SIZE, shift, &num_symbols,
                                      symbols));
  SizeWriter writer;
  // Ignore the correctness, no real encoding happens at this stage.
  (void)EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols, shift,
                     symbols, &writer);
  return writer.size +
         EstimateDataBits(histogram, counts.data(), alphabet_size);
}

StatusOr<uint32_t> ComputeBestMethod(
    const ANSHistBin* histogram, size_t alphabet_size, float* cost,
    HistogramParams::ANSHistogramStrategy ans_histogram_strategy) {
  uint32_t method = 0;
  JXL_ASSIGN_OR_RETURN(float fcost,
                       ComputeHistoAndDataCost(histogram, alphabet_size, 0));
  auto try_shift = [&](size_t shift) -> Status {
    JXL_ASSIGN_OR_RETURN(
        float c, ComputeHistoAndDataCost(histogram, alphabet_size, shift + 1));
    if (c < fcost) {
      method = shift + 1;
      fcost = c;
    }
    return true;
  };
  switch (ans_histogram_strategy) {
    case HistogramParams::ANSHistogramStrategy::kPrecise: {
      for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift++) {
        JXL_RETURN_IF_ERROR(try_shift(shift));
      }
      break;
    }
    case HistogramParams::ANSHistogramStrategy::kApproximate: {
      for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift += 2) {
        JXL_RETURN_IF_ERROR(try_shift(shift));
      }
      break;
    }
    case HistogramParams::ANSHistogramStrategy::kFast: {
      JXL_RETURN_IF_ERROR(try_shift(0));
      JXL_RETURN_IF_ERROR(try_shift(ANS_LOG_TAB_SIZE / 2));
      JXL_RETURN_IF_ERROR(try_shift(ANS_LOG_TAB_SIZE));
      break;
    }
  };
  *cost = fcost;
  return method;
}

}  // namespace

// Returns an estimate of the cost of encoding this histogram and the
// corresponding data.
StatusOr<size_t> BuildAndStoreANSEncodingData(
    JxlMemoryManager* memory_manager,
    HistogramParams::ANSHistogramStrategy ans_histogram_strategy,
    const ANSHistBin* histogram, size_t alphabet_size, size_t log_alpha_size,
    bool use_prefix_code, ANSEncSymbolInfo* info, BitWriter* writer) {
  if (use_prefix_code) {
    size_t cost = 0;
    if (alphabet_size <= 1) return 0;
    std::vector<uint32_t> histo(alphabet_size);
    for (size_t i = 0; i < alphabet_size; i++) {
      histo[i] = histogram[i];
      JXL_ENSURE(histogram[i] >= 0);
    }
    {
      std::vector<uint8_t> depths(alphabet_size);
      std::vector<uint16_t> bits(alphabet_size);
      if (writer == nullptr) {
        BitWriter tmp_writer{memory_manager};
        JXL_RETURN_IF_ERROR(tmp_writer.WithMaxBits(
            8 * alphabet_size + 8,  // safe upper bound
            LayerType::Header, /*aux_out=*/nullptr, [&] {
              return BuildAndStoreHuffmanTree(histo.data(), alphabet_size,
                                              depths.data(), bits.data(),
                                              &tmp_writer);
            }));
        cost = tmp_writer.BitsWritten();
      } else {
        size_t start = writer->BitsWritten();
        JXL_RETURN_IF_ERROR(BuildAndStoreHuffmanTree(
            histo.data(), alphabet_size, depths.data(), bits.data(), writer));
        cost = writer->BitsWritten() - start;
      }
      for (size_t i = 0; i < alphabet_size; i++) {
        info[i].bits = depths[i] == 0 ? 0 : bits[i];
        info[i].depth = depths[i];
      }
    }
    // Estimate data cost.
    for (size_t i = 0; i < alphabet_size; i++) {
      cost += histogram[i] * info[i].depth;
    }
    return cost;
  }
  JXL_ENSURE(alphabet_size <= ANS_TAB_SIZE);
  float fcost;
  JXL_ASSIGN_OR_RETURN(uint32_t method,
                       ComputeBestMethod(histogram, alphabet_size, &fcost,
                                         ans_histogram_strategy));
  JXL_ENSURE(fcost >= 0);
  int num_symbols;
  int symbols[kMaxNumSymbolsForSmallCode] = {};
  std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
  if (!counts.empty()) {
    size_t sum = 0;
    for (int count : counts) {
      sum += count;
    }
    if (sum == 0) {
      counts[0] = ANS_TAB_SIZE;
    }
  }
  int omit_pos = 0;
  uint32_t shift = method - 1;
  if (method == 0) {
    JXL_ENSURE(alphabet_size > 0);
    counts = CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE);
  } else {
    JXL_RETURN_IF_ERROR(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
                                        ANS_LOG_TAB_SIZE, shift, &num_symbols,
                                        symbols));
  }
  AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
  JXL_RETURN_IF_ERROR(
      InitAliasTable(counts, ANS_LOG_TAB_SIZE, log_alpha_size, a));
  ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
  if (writer != nullptr) {
    if (method == 0) {
      JXL_ENSURE(alphabet_size > 0);
      EncodeFlatHistogram(alphabet_size, writer);
    } else {
      if (!EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols,
                        method - 1, symbols, writer)) {
        return JXL_FAILURE("EncodeCounts failed");
      }
    }
  }
  return static_cast<size_t>(fcost);
}

StatusOr<float> ANSPopulationCost(const ANSHistBin* data,
                                  size_t alphabet_size) {
  float cost = 0.0f;
  if (ANS_MAX_ALPHABET_SIZE < alphabet_size) {
    return std::numeric_limits<float>::max();
  }
  JXL_ASSIGN_OR_RETURN(
      uint32_t method,
      ComputeBestMethod(data, alphabet_size, &cost,
                        HistogramParams::ANSHistogramStrategy::kFast));
  (void)method;
  return cost;
}

template <typename Writer>
void EncodeUintConfig(const HybridUintConfig uint_config, Writer* writer,
                      size_t log_alpha_size) {
  writer->Write(CeilLog2Nonzero(log_alpha_size + 1),
                uint_config.split_exponent);
  if (uint_config.split_exponent == log_alpha_size) {
    return;  // msb/lsb don't matter.
  }
  size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1);
  writer->Write(nbits, uint_config.msb_in_token);
  nbits = CeilLog2Nonzero(uint_config.split_exponent -
                          uint_config.msb_in_token + 1);
  writer->Write(nbits, uint_config.lsb_in_token);
}
template <typename Writer>
void EncodeUintConfigs(const std::vector<HybridUintConfig>& uint_config,
                       Writer* writer, size_t log_alpha_size) {
  // TODO(veluca): RLE?
  for (const auto& cfg : uint_config) {
    EncodeUintConfig(cfg, writer, log_alpha_size);
  }
}
template void EncodeUintConfigs(const std::vector<HybridUintConfig>&,
                                BitWriter*, size_t);

namespace {

Status ChooseUintConfigs(const HistogramParams& params,
                         const std::vector<std::vector<Token>>& tokens,
                         const std::vector<uint8_t>& context_map,
                         std::vector<Histogram>* clustered_histograms,
                         EntropyEncodingData* codes, size_t* log_alpha_size) {
  codes->uint_config.resize(clustered_histograms->size());
  if (params.uint_method == HistogramParams::HybridUintMethod::kNone) {
    return true;
  }
  if (params.uint_method == HistogramParams::HybridUintMethod::k000) {
    codes->uint_config.clear();
    codes->uint_config.resize(clustered_histograms->size(),
                              HybridUintConfig(0, 0, 0));
    return true;
  }
  if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
    codes->uint_config.clear();
    codes->uint_config.resize(clustered_histograms->size(),
                              HybridUintConfig(2, 0, 1));
    return true;
  }

  // If the uint config is adaptive, just stick with the default in streaming
  // mode.
  if (params.streaming_mode) {
    return true;
  }

  // Brute-force method that tries a few options.
  std::vector<HybridUintConfig> configs;
  if (params.uint_method == HistogramParams::HybridUintMethod::kBest) {
    configs = {
        HybridUintConfig(4, 2, 0),  // default
        HybridUintConfig(4, 1, 0),  // less precise
        HybridUintConfig(4, 2, 1),  // add sign
        HybridUintConfig(4, 2, 2),  // add sign+parity
        HybridUintConfig(4, 1, 2),  // add parity but less msb
        // Same as above, but more direct coding.
        HybridUintConfig(5, 2, 0), HybridUintConfig(5, 1, 0),
        HybridUintConfig(5, 2, 1), HybridUintConfig(5, 2, 2),
        HybridUintConfig(5, 1, 2),
        // Same as above, but less direct coding.
        HybridUintConfig(3, 2, 0), HybridUintConfig(3, 1, 0),
        HybridUintConfig(3, 2, 1), HybridUintConfig(3, 1, 2),
        // For near-lossless.
        HybridUintConfig(4, 1, 3), HybridUintConfig(5, 1, 4),
        HybridUintConfig(5, 2, 3), HybridUintConfig(6, 1, 5),
        HybridUintConfig(6, 2, 4), HybridUintConfig(6, 0, 0),
        // Other
        HybridUintConfig(0, 0, 0),   // varlenuint
        HybridUintConfig(2, 0, 1),   // works well for ctx map
        HybridUintConfig(7, 0, 0),   // direct coding
        HybridUintConfig(8, 0, 0),   // direct coding
        HybridUintConfig(9, 0, 0),   // direct coding
        HybridUintConfig(10, 0, 0),  // direct coding
        HybridUintConfig(11, 0, 0),  // direct coding
        HybridUintConfig(12, 0, 0),  // direct coding
    };
  } else if (params.uint_method == HistogramParams::HybridUintMethod::kFast) {
    configs = {
        HybridUintConfig(4, 2, 0),  // default
        HybridUintConfig(4, 1, 2),  // add parity but less msb
        HybridUintConfig(0, 0, 0),  // smallest histograms
        HybridUintConfig(2, 0, 1),  // works well for ctx map
    };
  }

  std::vector<float> costs(clustered_histograms->size(),
                           std::numeric_limits<float>::max());
  std::vector<uint32_t> extra_bits(clustered_histograms->size());
  std::vector<uint8_t> is_valid(clustered_histograms->size());
  size_t max_alpha =
      codes->use_prefix_code ? PREFIX_MAX_ALPHABET_SIZE : ANS_MAX_ALPHABET_SIZE;
  for (HybridUintConfig cfg : configs) {
    std::fill(is_valid.begin(), is_valid.end(), true);
    std::fill(extra_bits.begin(), extra_bits.end(), 0);

    for (auto& histo : *clustered_histograms) {
      histo.Clear();
    }
    for (const auto& stream : tokens) {
      for (const auto& token : stream) {
        // TODO(veluca): do not ignore lz77 commands.
        if (token.is_lz77_length) continue;
        size_t histo = context_map[token.context];
        uint32_t tok, nbits, bits;
        cfg.Encode(token.value, &tok, &nbits, &bits);
        if (tok >= max_alpha ||
            (codes->lz77.enabled && tok >= codes->lz77.min_symbol)) {
          is_valid[histo] = JXL_FALSE;
          continue;
        }
        extra_bits[histo] += nbits;
        (*clustered_histograms)[histo].Add(tok);
      }
    }

    for (size_t i = 0; i < clustered_histograms->size(); i++) {
      if (!is_valid[i]) continue;
      JXL_ASSIGN_OR_RETURN(float cost,
                           (*clustered_histograms)[i].PopulationCost());
      cost += extra_bits[i];
      // add signaling cost of the hybriduintconfig itself
      cost += CeilLog2Nonzero(cfg.split_exponent + 1);
      cost += CeilLog2Nonzero(cfg.split_exponent - cfg.msb_in_token + 1);
      if (cost < costs[i]) {
        codes->uint_config[i] = cfg;
        costs[i] = cost;
      }
    }
  }

  // Rebuild histograms.
  for (auto& histo : *clustered_histograms) {
    histo.Clear();
  }
  *log_alpha_size = 4;
  for (const auto& stream : tokens) {
    for (const auto& token : stream) {
      uint32_t tok, nbits, bits;
      size_t histo = context_map[token.context];
      (token.is_lz77_length ? codes->lz77.length_uint_config
                            : codes->uint_config[histo])
          .Encode(token.value, &tok, &nbits, &bits);
      tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
      (*clustered_histograms)[histo].Add(tok);
      while (tok >= (1u << *log_alpha_size)) (*log_alpha_size)++;
    }
  }
  size_t max_log_alpha_size = codes->use_prefix_code ? PREFIX_MAX_BITS : 8;
  JXL_ENSURE(*log_alpha_size <= max_log_alpha_size);
  return true;
}

Histogram HistogramFromSymbolInfo(
    const std::vector<ANSEncSymbolInfo>& encoding_info, bool use_prefix_code) {
  Histogram histo;
  histo.data_.resize(DivCeil(encoding_info.size(), Histogram::kRounding) *
                     Histogram::kRounding);
  histo.total_count_ = 0;
  for (size_t i = 0; i < encoding_info.size(); ++i) {
    const ANSEncSymbolInfo& info = encoding_info[i];
    int count = use_prefix_code
                    ? (info.depth ? (1u << (PREFIX_MAX_BITS - info.depth)) : 0)
                    : info.freq_;
    histo.data_[i] = count;
    histo.total_count_ += count;
  }
  return histo;
}

class HistogramBuilder {
 public:
  explicit HistogramBuilder(const size_t num_contexts)
      : histograms_(num_contexts) {}

  void VisitSymbol(int symbol, size_t histo_idx) {
    JXL_DASSERT(histo_idx < histograms_.size());
    histograms_[histo_idx].Add(symbol);
  }

  // NOTE: `layer` is only for clustered_entropy; caller does ReclaimAndCharge.
  // Returns cost (in bits).
  StatusOr<size_t> BuildAndStoreEntropyCodes(
      JxlMemoryManager* memory_manager, const HistogramParams& params,
      const std::vector<std::vector<Token>>& tokens, EntropyEncodingData* codes,
      std::vector<uint8_t>* context_map, BitWriter* writer, LayerType layer,
      AuxOut* aux_out) const {
    const size_t prev_histograms = codes->encoding_info.size();
    std::vector<Histogram> clustered_histograms;
    for (size_t i = 0; i < prev_histograms; ++i) {
      clustered_histograms.push_back(HistogramFromSymbolInfo(
          codes->encoding_info[i], codes->use_prefix_code));
    }
    size_t context_offset = context_map->size();
    context_map->resize(context_offset + histograms_.size());
    if (histograms_.size() > 1) {
      if (!ans_fuzzer_friendly_) {
        std::vector<uint32_t> histogram_symbols;
        JXL_RETURN_IF_ERROR(
            ClusterHistograms(params, histograms_, kClustersLimit,
                              &clustered_histograms, &histogram_symbols));
        for (size_t c = 0; c < histograms_.size(); ++c) {
          (*context_map)[context_offset + c] =
              static_cast<uint8_t>(histogram_symbols[c]);
        }
      } else {
        JXL_ENSURE(codes->encoding_info.empty());
        fill(context_map->begin(), context_map->end(), 0);
        size_t max_symbol = 0;
        for (const Histogram& h : histograms_) {
          max_symbol = std::max(h.data_.size(), max_symbol);
        }
        size_t num_symbols = 1 << CeilLog2Nonzero(max_symbol + 1);
        clustered_histograms.resize(1);
        clustered_histograms[0].Clear();
        for (size_t i = 0; i < num_symbols; i++) {
          clustered_histograms[0].Add(i);
        }
      }
      if (writer != nullptr) {
        JXL_RETURN_IF_ERROR(EncodeContextMap(
            *context_map, clustered_histograms.size(), writer, layer, aux_out));
      }
    } else {
      JXL_ENSURE(codes->encoding_info.empty());
      clustered_histograms.push_back(histograms_[0]);
    }
    if (aux_out != nullptr) {
      for (size_t i = prev_histograms; i < clustered_histograms.size(); ++i) {
        aux_out->layer(layer).clustered_entropy +=
            clustered_histograms[i].ShannonEntropy();
      }
    }
    size_t log_alpha_size = codes->lz77.enabled ? 8 : 7;  // Sane default.
    if (ans_fuzzer_friendly_) {
      codes->uint_config.clear();
      codes->uint_config.resize(1, HybridUintConfig(7, 0, 0));
    } else {
      JXL_RETURN_IF_ERROR(ChooseUintConfigs(params, tokens, *context_map,
                                            &clustered_histograms, codes,
                                            &log_alpha_size));
    }
    if (log_alpha_size < 5) log_alpha_size = 5;
    if (params.streaming_mode) {
      // TODO(szabadka) Figure out if we can use lower values here.
      log_alpha_size = 8;
    }
    SizeWriter size_writer;  // Used if writer == nullptr to estimate costs.
    size_t cost = 1;
    if (writer) writer->Write(1, TO_JXL_BOOL(codes->use_prefix_code));

    if (codes->use_prefix_code) {
      log_alpha_size = PREFIX_MAX_BITS;
    } else {
      cost += 2;
    }
    if (writer == nullptr) {
      EncodeUintConfigs(codes->uint_config, &size_writer, log_alpha_size);
    } else {
      if (!codes->use_prefix_code) writer->Write(2, log_alpha_size - 5);
      EncodeUintConfigs(codes->uint_config, writer, log_alpha_size);
    }
    if (codes->use_prefix_code) {
      for (const auto& histo : clustered_histograms) {
        size_t alphabet_size = histo.alphabet_size();
        if (writer) {
          StoreVarLenUint16(alphabet_size - 1, writer);
        } else {
          StoreVarLenUint16(alphabet_size - 1, &size_writer);
        }
      }
    }
    cost += size_writer.size;
    for (size_t c = prev_histograms; c < clustered_histograms.size(); ++c) {
      size_t alphabet_size = clustered_histograms[c].alphabet_size();
      codes->encoding_info.emplace_back();
      codes->encoding_info.back().resize(alphabet_size);
      BitWriter* histo_writer = writer;
      if (params.streaming_mode) {
        codes->encoded_histograms.emplace_back(memory_manager);
        histo_writer = &codes->encoded_histograms.back();
      }
      const auto& body = [&]() -> Status {
        JXL_ASSIGN_OR_RETURN(
            size_t ans_cost,
            BuildAndStoreANSEncodingData(
                memory_manager, params.ans_histogram_strategy,
                clustered_histograms[c].data_.data(), alphabet_size,
                log_alpha_size, codes->use_prefix_code,
                codes->encoding_info.back().data(), histo_writer));
        cost += ans_cost;
        return true;
      };
      if (histo_writer) {
        JXL_RETURN_IF_ERROR(histo_writer->WithMaxBits(
            256 + alphabet_size * 24, layer, aux_out, body,
            /*finished_histogram=*/true));
      } else {
        JXL_RETURN_IF_ERROR(body());
      }
      if (params.streaming_mode) {
        JXL_RETURN_IF_ERROR(writer->AppendUnaligned(*histo_writer));
      }
    }
    return cost;
  }

  const Histogram& Histo(size_t i) const { return histograms_[i]; }

 private:
  std::vector<Histogram> histograms_;
};

class SymbolCostEstimator {
 public:
  SymbolCostEstimator(size_t num_contexts, bool force_huffman,
                      const std::vector<std::vector<Token>>& tokens,
                      const LZ77Params& lz77) {
    HistogramBuilder builder(num_contexts);
    // Build histograms for estimating lz77 savings.
    HybridUintConfig uint_config;
    for (const auto& stream : tokens) {
      for (const auto& token : stream) {
        uint32_t tok, nbits, bits;
        (token.is_lz77_length ? lz77.length_uint_config : uint_config)
            .Encode(token.value, &tok, &nbits, &bits);
        tok += token.is_lz77_length ? lz77.min_symbol : 0;
        builder.VisitSymbol(tok, token.context);
      }
    }
    max_alphabet_size_ = 0;
    for (size_t i = 0; i < num_contexts; i++) {
      max_alphabet_size_ =
          std::max(max_alphabet_size_, builder.Histo(i).data_.size());
    }
    bits_.resize(num_contexts * max_alphabet_size_);
    // TODO(veluca): SIMD?
    add_symbol_cost_.resize(num_contexts);
    for (size_t i = 0; i < num_contexts; i++) {
      float inv_total = 1.0f / (builder.Histo(i).total_count_ + 1e-8f);
      float total_cost = 0;
      for (size_t j = 0; j < builder.Histo(i).data_.size(); j++) {
        size_t cnt = builder.Histo(i).data_[j];
        float cost = 0;
        if (cnt != 0 && cnt != builder.Histo(i).total_count_) {
          cost = -FastLog2f(cnt * inv_total);
          if (force_huffman) cost = std::ceil(cost);
        } else if (cnt == 0) {
          cost = ANS_LOG_TAB_SIZE;  // Highest possible cost.
        }
        bits_[i * max_alphabet_size_ + j] = cost;
        total_cost += cost * builder.Histo(i).data_[j];
      }
      // Penalty for adding a lz77 symbol to this contest (only used for static
      // cost model). Higher penalty for contexts that have a very low
      // per-symbol entropy.
      add_symbol_cost_[i] = std::max(0.0f, 6.0f - total_cost * inv_total);
    }
  }
  float Bits(size_t ctx, size_t sym) const {
    return bits_[ctx * max_alphabet_size_ + sym];
  }
  float LenCost(size_t ctx, size_t len, const LZ77Params& lz77) const {
    uint32_t nbits, bits, tok;
    lz77.length_uint_config.Encode(len, &tok, &nbits, &bits);
    tok += lz77.min_symbol;
    return nbits + Bits(ctx, tok);
  }
  float DistCost(size_t len, const LZ77Params& lz77) const {
    uint32_t nbits, bits, tok;
    HybridUintConfig().Encode(len, &tok, &nbits, &bits);
    return nbits + Bits(lz77.nonserialized_distance_context, tok);
  }
  float AddSymbolCost(size_t idx) const { return add_symbol_cost_[idx]; }

 private:
  size_t max_alphabet_size_;
  std::vector<float> bits_;
  std::vector<float> add_symbol_cost_;
};

void ApplyLZ77_RLE(const HistogramParams& params, size_t num_contexts,
                   const std::vector<std::vector<Token>>& tokens,
                   LZ77Params& lz77,
                   std::vector<std::vector<Token>>& tokens_lz77) {
  // TODO(veluca): tune heuristics here.
  SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
  float bit_decrease = 0;
  size_t total_symbols = 0;
  tokens_lz77.resize(tokens.size());
  std::vector<float> sym_cost;
  HybridUintConfig uint_config;
  for (size_t stream = 0; stream < tokens.size(); stream++) {
    size_t distance_multiplier =
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
    const auto& in = tokens[stream];
    auto& out = tokens_lz77[stream];
    total_symbols += in.size();
    // Cumulative sum of bit costs.
    sym_cost.resize(in.size() + 1);
    for (size_t i = 0; i < in.size(); i++) {
      uint32_t tok, nbits, unused_bits;
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
    }
    out.reserve(in.size());
    for (size_t i = 0; i < in.size(); i++) {
      size_t num_to_copy = 0;
      size_t distance_symbol = 0;  // 1 for RLE.
      if (distance_multiplier != 0) {
        distance_symbol = 1;  // Special distance 1 if enabled.
        JXL_DASSERT(kSpecialDistances[1][0] == 1);
        JXL_DASSERT(kSpecialDistances[1][1] == 0);
      }
      if (i > 0) {
        for (; i + num_to_copy < in.size(); num_to_copy++) {
          if (in[i + num_to_copy].value != in[i - 1].value) {
            break;
          }
        }
      }
      if (num_to_copy == 0) {
        out.push_back(in[i]);
        continue;
      }
      float cost = sym_cost[i + num_to_copy] - sym_cost[i];
      // This subtraction might overflow, but that's OK.
      size_t lz77_len = num_to_copy - lz77.min_length;
      float lz77_cost = num_to_copy >= lz77.min_length
                            ? CeilLog2Nonzero(lz77_len + 1) + 1
                            : 0;
      if (num_to_copy < lz77.min_length || cost <= lz77_cost) {
        for (size_t j = 0; j < num_to_copy; j++) {
          out.push_back(in[i + j]);
        }
        i += num_to_copy - 1;
        continue;
      }
      // Output the LZ77 length
      out.emplace_back(in[i].context, lz77_len);
      out.back().is_lz77_length = true;
      i += num_to_copy - 1;
      bit_decrease += cost - lz77_cost;
      // Output the LZ77 copy distance.
      out.emplace_back(lz77.nonserialized_distance_context, distance_symbol);
    }
  }

  if (bit_decrease > total_symbols * 0.2 + 16) {
    lz77.enabled = true;
  }
}

// Hash chain for LZ77 matching
struct HashChain {
  size_t size_;
  std::vector<uint32_t> data_;

  unsigned hash_num_values_ = 32768;
  unsigned hash_mask_ = hash_num_values_ - 1;
  unsigned hash_shift_ = 5;

  std::vector<int> head;
  std::vector<uint32_t> chain;
  std::vector<int> val;

  // Speed up repetitions of zero
  std::vector<int> headz;
  std::vector<uint32_t> chainz;
  std::vector<uint32_t> zeros;
  uint32_t numzeros = 0;

  size_t window_size_;
  size_t window_mask_;
  size_t min_length_;
  size_t max_length_;

  // Map of special distance codes.
  std::unordered_map<intint> special_dist_table_;
  size_t num_special_distances_ = 0;

  uint32_t maxchainlength = 256;  // window_size_ to allow all

  HashChain(const Token* data, size_t size, size_t window_size,
            size_t min_length, size_t max_length, size_t distance_multiplier)
      : size_(size),
        window_size_(window_size),
        window_mask_(window_size - 1),
        min_length_(min_length),
        max_length_(max_length) {
    data_.resize(size);
    for (size_t i = 0; i < size; i++) {
      data_[i] = data[i].value;
    }

    head.resize(hash_num_values_, -1);
    val.resize(window_size_, -1);
    chain.resize(window_size_);
    for (uint32_t i = 0; i < window_size_; ++i) {
      chain[i] = i;  // same value as index indicates uninitialized
    }

    zeros.resize(window_size_);
    headz.resize(window_size_ + 1, -1);
    chainz.resize(window_size_);
    for (uint32_t i = 0; i < window_size_; ++i) {
      chainz[i] = i;
    }
    // Translate distance to special distance code.
    if (distance_multiplier) {
      // Count down, so if due to small distance multiplier multiple distances
      // map to the same code, the smallest code will be used in the end.
      for (int i = kNumSpecialDistances - 1; i >= 0; --i) {
        special_dist_table_[SpecialDistance(i, distance_multiplier)] = i;
      }
      num_special_distances_ = kNumSpecialDistances;
    }
  }

  uint32_t GetHash(size_t pos) const {
    uint32_t result = 0;
    if (pos + 2 < size_) {
      // TODO(lode): take the MSB's of the uint32_t values into account as well,
      // given that the hash code itself is less than 32 bits.
      result ^= static_cast<uint32_t>(data_[pos + 0] << 0u);
      result ^= static_cast<uint32_t>(data_[pos + 1] << hash_shift_);
      result ^= static_cast<uint32_t>(data_[pos + 2] << (hash_shift_ * 2));
    } else {
      // No need to compute hash of last 2 bytes, the length 2 is too short.
      return 0;
    }
    return result & hash_mask_;
  }

  uint32_t CountZeros(size_t pos, uint32_t prevzeros) const {
    size_t end = pos + window_size_;
    if (end > size_) end = size_;
    if (prevzeros > 0) {
      if (prevzeros >= window_mask_ && data_[end - 1] == 0 &&
          end == pos + window_size_) {
        return prevzeros;
      } else {
        return prevzeros - 1;
      }
    }
    uint32_t num = 0;
    while (pos + num < end && data_[pos + num] == 0) num++;
    return num;
  }

  void Update(size_t pos) {
    uint32_t hashval = GetHash(pos);
    uint32_t wpos = pos & window_mask_;

    val[wpos] = static_cast<int>(hashval);
    if (head[hashval] != -1) chain[wpos] = head[hashval];
    head[hashval] = wpos;

    if (pos > 0 && data_[pos] != data_[pos - 1]) numzeros = 0;
    numzeros = CountZeros(pos, numzeros);

    zeros[wpos] = numzeros;
    if (headz[numzeros] != -1) chainz[wpos] = headz[numzeros];
    headz[numzeros] = wpos;
  }

  void Update(size_t pos, size_t len) {
    for (size_t i = 0; i < len; i++) {
      Update(pos + i);
    }
  }

  template <typename CB>
  void FindMatches(size_t pos, int max_dist, const CB& found_match) const {
    uint32_t wpos = pos & window_mask_;
    uint32_t hashval = GetHash(pos);
    uint32_t hashpos = chain[wpos];

    int prev_dist = 0;
    int end = std::min<int>(pos + max_length_, size_);
    uint32_t chainlength = 0;
    uint32_t best_len = 0;
    for (;;) {
      int dist = (hashpos <= wpos) ? (wpos - hashpos)
                                   : (wpos - hashpos + window_mask_ + 1);
      if (dist < prev_dist) break;
      prev_dist = dist;
      uint32_t len = 0;
      if (dist > 0) {
        int i = pos;
        int j = pos - dist;
        if (numzeros > 3) {
          int r = std::min<int>(numzeros - 1, zeros[hashpos]);
          if (i + r >= end) r = end - i - 1;
          i += r;
          j += r;
        }
        while (i < end && data_[i] == data_[j]) {
          i++;
          j++;
        }
        len = i - pos;
        // This can trigger even if the new length is slightly smaller than the
        // best length, because it is possible for a slightly cheaper distance
        // symbol to occur.
        if (len >= min_length_ && len + 2 >= best_len) {
          auto it = special_dist_table_.find(dist);
          int dist_symbol = (it == special_dist_table_.end())
                                ? (num_special_distances_ + dist - 1)
                                : it->second;
          found_match(len, dist_symbol);
          if (len > best_len) best_len = len;
        }
      }

      chainlength++;
      if (chainlength >= maxchainlength) break;

      if (numzeros >= 3 && len > numzeros) {
        if (hashpos == chainz[hashpos]) break;
        hashpos = chainz[hashpos];
        if (zeros[hashpos] != numzeros) break;
      } else {
        if (hashpos == chain[hashpos]) break;
        hashpos = chain[hashpos];
        if (val[hashpos] != static_cast<int>(hashval)) {
          // outdated hash value
          break;
        }
      }
    }
  }
  void FindMatch(size_t pos, int max_dist, size_t* result_dist_symbol,
                 size_t* result_len) const {
    *result_dist_symbol = 0;
    *result_len = 1;
    FindMatches(pos, max_dist, [&](size_t len, size_t dist_symbol) {
      if (len > *result_len ||
          (len == *result_len && *result_dist_symbol > dist_symbol)) {
        *result_len = len;
        *result_dist_symbol = dist_symbol;
      }
    });
  }
};

float LenCost(size_t len) {
  uint32_t nbits, bits, tok;
  HybridUintConfig(1, 0, 0).Encode(len, &tok, &nbits, &bits);
  constexpr float kCostTable[] = {
      2.797667318563126,  3.213177690381199,  2.5706009246743737,
      2.408392498667534,  2.829649191872326,  3.3923087753324577,
      4.029267451554331,  4.415576699706408,  4.509357574741465,
      9.21481543803004,   10.020590190114898, 11.858671627804766,
      12.45853300490526,  11.713105831990857, 12.561996324849314,
      13.775477692278367, 13.174027068768641,
  };
  size_t table_size = sizeof kCostTable / sizeof *kCostTable;
  if (tok >= table_size) tok = table_size - 1;
  return kCostTable[tok] + nbits;
}

// TODO(veluca): this does not take into account usage or non-usage of distance
// multipliers.
float DistCost(size_t dist) {
  uint32_t nbits, bits, tok;
  HybridUintConfig(7, 0, 0).Encode(dist, &tok, &nbits, &bits);
  constexpr float kCostTable[] = {
      6.368282626312716,  5.680793277090298,  8.347404197105247,
      7.641619201599141,  6.914328374119438,  7.959808291537444,
      8.70023120759855,   8.71378518934703,   9.379132523982769,
      9.110472749092708,  9.159029569270908,  9.430936766731973,
      7.278284055315169,  7.8278514904267755, 10.026641158289236,
      9.976049229827066,  9.64351607048908,   9.563403863480442,
      10.171474111762747, 10.45950155077234,  9.994813912104219,
      10.322524683741156, 8.465808729388186,  8.756254166066853,
      10.160930174662234, 10.247329273413435, 10.04090403724809,
      10.129398517544082, 9.342311691539546,  9.07608009102374,
      10.104799540677513, 10.378079384990906, 10.165828974075072,
      10.337595322341553, 7.940557464567944,  10.575665823319431,
      11.023344321751955, 10.736144698831827, 11.118277044595054,
      7.468468230648442,  10.738305230932939, 10.906980780216568,
      10.163468216353817, 10.17805759656433,  11.167283670483565,
      11.147050200274544, 10.517921919244333, 10.651764778156886,
      10.17074446448919,  11.217636876224745, 11.261630721139484,
      11.403140815247259, 10.892472096873417, 11.1859607804481,
      8.017346947551262,  7.895143720278828,  11.036577113822025,
      11.170562110315794, 10.326988722591086, 10.40872184751056,
      11.213498225466386, 11.30580635516863,  10.672272515665442,
      10.768069466228063, 11.145257364153565, 11.64668307145549,
      10.593156194627339, 11.207499484844943, 10.767517766396908,
      10.826629811407042, 10.737764794499988, 10.6200448518045,
      10.191315385198092, 8.468384171390085,  11.731295299170432,
      11.824619886654398, 10.41518844301179,  10.16310536548649,
      10.539423685097576, 10.495136599328031, 10.469112847728267,
      11.72057686174922,  10.910326337834674, 11.378921834673758,
      11.847759036098536, 11.92071647623854,  10.810628276345282,
      11.008601085273893, 11.910326337834674, 11.949212023423133,
      11.298614839104337, 11.611603659010392, 10.472930394619985,
      11.835564720850282, 11.523267392285337, 12.01055816679611,
      8.413029688994023,  11.895784139536406, 11.984679534970505,
      11.220654278717394, 11.716311684833672, 10.61036646226114,
      10.89849965960364,  10.203762898863669, 10.997560826267238,
      11.484217379438984, 11.792836176993665, 12.24310468755171,
      11.464858097919262, 12.212747017409377, 11.425595666074955,
      11.572048533398757, 12.742093965163013, 11.381874288645637,
      12.191870445817015, 11.683156920035426, 11.152442115262197,
      11.90303691580457,  11.653292787169159, 11.938615382266098,
      16.970641701570223, 16.853602280380002, 17.26240782594733,
      16.644655390108507, 17.14310889757499,  16.910935455445955,
      17.505678976959697, 17.213498225466388, 2.4162310293553024,
      3.494587244462329,  3.5258600986408344, 3.4959806589517095,
      3.098390886949687,  3.343454654302911,  3.588847442290287,
      4.14614790111827,   5.152948641990529,  7.433696808092598,
      9.716311684833672,
  };
  size_t table_size = sizeof kCostTable / sizeof *kCostTable;
  if (tok >= table_size) tok = table_size - 1;
  return kCostTable[tok] + nbits;
}

void ApplyLZ77_LZ77(const HistogramParams& params, size_t num_contexts,
                    const std::vector<std::vector<Token>>& tokens,
                    LZ77Params& lz77,
                    std::vector<std::vector<Token>>& tokens_lz77) {
  // TODO(veluca): tune heuristics here.
  SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
  float bit_decrease = 0;
  size_t total_symbols = 0;
  tokens_lz77.resize(tokens.size());
  HybridUintConfig uint_config;
  std::vector<float> sym_cost;
  for (size_t stream = 0; stream < tokens.size(); stream++) {
    size_t distance_multiplier =
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
    const auto& in = tokens[stream];
    auto& out = tokens_lz77[stream];
    total_symbols += in.size();
    // Cumulative sum of bit costs.
    sym_cost.resize(in.size() + 1);
    for (size_t i = 0; i < in.size(); i++) {
      uint32_t tok, nbits, unused_bits;
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
    }

    out.reserve(in.size());
    size_t max_distance = in.size();
    size_t min_length = lz77.min_length;
    JXL_DASSERT(min_length >= 3);
    size_t max_length = in.size();

    // Use next power of two as window size.
    size_t window_size = 1;
    while (window_size < max_distance && window_size < kWindowSize) {
      window_size <<= 1;
    }

    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
                    distance_multiplier);
    size_t len;
    size_t dist_symbol;

    const size_t max_lazy_match_len = 256;  // 0 to disable lazy matching

    // Whether the next symbol was already updated (to test lazy matching)
    bool already_updated = false;
    for (size_t i = 0; i < in.size(); i++) {
      out.push_back(in[i]);
      if (!already_updated) chain.Update(i);
      already_updated = false;
      chain.FindMatch(i, max_distance, &dist_symbol, &len);
      if (len >= min_length) {
        if (len < max_lazy_match_len && i + 1 < in.size()) {
          // Try length at next symbol lazy matching
          chain.Update(i + 1);
          already_updated = true;
          size_t len2, dist_symbol2;
          chain.FindMatch(i + 1, max_distance, &dist_symbol2, &len2);
          if (len2 > len) {
            // Use the lazy match. Add literal, and use the next length starting
            // from the next byte.
            ++i;
            already_updated = false;
            len = len2;
            dist_symbol = dist_symbol2;
            out.push_back(in[i]);
          }
        }

        float cost = sym_cost[i + len] - sym_cost[i];
        size_t lz77_len = len - lz77.min_length;
        float lz77_cost = LenCost(lz77_len) + DistCost(dist_symbol) +
                          sce.AddSymbolCost(out.back().context);

        if (lz77_cost <= cost) {
          out.back().value = len - min_length;
          out.back().is_lz77_length = true;
          out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
          bit_decrease += cost - lz77_cost;
        } else {
          // LZ77 match ignored, and symbol already pushed. Push all other
          // symbols and skip.
          for (size_t j = 1; j < len; j++) {
            out.push_back(in[i + j]);
          }
        }

        if (already_updated) {
          chain.Update(i + 2, len - 2);
          already_updated = false;
        } else {
          chain.Update(i + 1, len - 1);
        }
        i += len - 1;
      } else {
        // Literal, already pushed
      }
    }
  }

  if (bit_decrease > total_symbols * 0.2 + 16) {
    lz77.enabled = true;
  }
}

void ApplyLZ77_Optimal(const HistogramParams& params, size_t num_contexts,
                       const std::vector<std::vector<Token>>& tokens,
                       LZ77Params& lz77,
                       std::vector<std::vector<Token>>& tokens_lz77) {
  std::vector<std::vector<Token>> tokens_for_cost_estimate;
  ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_for_cost_estimate);
  // If greedy-LZ77 does not give better compression than no-lz77, no reason to
  // run the optimal matching.
  if (!lz77.enabled) return;
  SymbolCostEstimator sce(num_contexts + 1, params.force_huffman,
                          tokens_for_cost_estimate, lz77);
  tokens_lz77.resize(tokens.size());
  HybridUintConfig uint_config;
  std::vector<float> sym_cost;
  std::vector<uint32_t> dist_symbols;
  for (size_t stream = 0; stream < tokens.size(); stream++) {
    size_t distance_multiplier =
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
    const auto& in = tokens[stream];
    auto& out = tokens_lz77[stream];
    // Cumulative sum of bit costs.
    sym_cost.resize(in.size() + 1);
    for (size_t i = 0; i < in.size(); i++) {
      uint32_t tok, nbits, unused_bits;
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
    }

    out.reserve(in.size());
    size_t max_distance = in.size();
    size_t min_length = lz77.min_length;
    JXL_DASSERT(min_length >= 3);
    size_t max_length = in.size();

    // Use next power of two as window size.
    size_t window_size = 1;
    while (window_size < max_distance && window_size < kWindowSize) {
      window_size <<= 1;
    }

    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
                    distance_multiplier);

    struct MatchInfo {
      uint32_t len;
      uint32_t dist_symbol;
      uint32_t ctx;
      float total_cost = std::numeric_limits<float>::max();
    };
    // Total cost to encode the first N symbols.
    std::vector<MatchInfo> prefix_costs(in.size() + 1);
    prefix_costs[0].total_cost = 0;

    size_t rle_length = 0;
    size_t skip_lz77 = 0;
    for (size_t i = 0; i < in.size(); i++) {
      chain.Update(i);
      float lit_cost =
          prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i];
      if (prefix_costs[i + 1].total_cost > lit_cost) {
        prefix_costs[i + 1].dist_symbol = 0;
        prefix_costs[i + 1].len = 1;
        prefix_costs[i + 1].ctx = in[i].context;
        prefix_costs[i + 1].total_cost = lit_cost;
      }
      if (skip_lz77 > 0) {
        skip_lz77--;
        continue;
      }
      dist_symbols.clear();
      chain.FindMatches(i, max_distance,
                        [&dist_symbols](size_t len, size_t dist_symbol) {
                          if (dist_symbols.size() <= len) {
                            dist_symbols.resize(len + 1, dist_symbol);
                          }
                          if (dist_symbol < dist_symbols[len]) {
                            dist_symbols[len] = dist_symbol;
                          }
                        });
      if (dist_symbols.size() <= min_length) continue;
      {
        size_t best_cost = dist_symbols.back();
        for (size_t j = dist_symbols.size() - 1; j >= min_length; j--) {
          if (dist_symbols[j] < best_cost) {
            best_cost = dist_symbols[j];
          }
          dist_symbols[j] = best_cost;
        }
      }
      for (size_t j = min_length; j < dist_symbols.size(); j++) {
        // Cost model that uses results from lazy LZ77.
        float lz77_cost = sce.LenCost(in[i].context, j - min_length, lz77) +
                          sce.DistCost(dist_symbols[j], lz77);
        float cost = prefix_costs[i].total_cost + lz77_cost;
        if (prefix_costs[i + j].total_cost > cost) {
          prefix_costs[i + j].len = j;
          prefix_costs[i + j].dist_symbol = dist_symbols[j] + 1;
          prefix_costs[i + j].ctx = in[i].context;
          prefix_costs[i + j].total_cost = cost;
        }
      }
      // We are in a RLE sequence: skip all the symbols except the first 8 and
      // the last 8. This avoid quadratic costs for sequences with long runs of
      // the same symbol.
      if ((dist_symbols.back() == 0 && distance_multiplier == 0) ||
          (dist_symbols.back() == 1 && distance_multiplier != 0)) {
        rle_length++;
      } else {
        rle_length = 0;
      }
      if (rle_length >= 8 && dist_symbols.size() > 9) {
        skip_lz77 = dist_symbols.size() - 10;
        rle_length = 0;
      }
    }
    size_t pos = in.size();
    while (pos > 0) {
      bool is_lz77_length = prefix_costs[pos].dist_symbol != 0;
      if (is_lz77_length) {
        size_t dist_symbol = prefix_costs[pos].dist_symbol - 1;
        out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
      }
      size_t val = is_lz77_length ? prefix_costs[pos].len - min_length
                                  : in[pos - 1].value;
      out.emplace_back(prefix_costs[pos].ctx, val);
      out.back().is_lz77_length = is_lz77_length;
      pos -= prefix_costs[pos].len;
    }
    std::reverse(out.begin(), out.end());
  }
}

void ApplyLZ77(const HistogramParams& params, size_t num_contexts,
               const std::vector<std::vector<Token>>& tokens, LZ77Params& lz77,
               std::vector<std::vector<Token>>& tokens_lz77) {
  if (params.initialize_global_state) {
    lz77.enabled = false;
  }
  if (params.force_huffman) {
    lz77.min_symbol = std::min(PREFIX_MAX_ALPHABET_SIZE - 32, 512);
  } else {
    lz77.min_symbol = 224;
  }
  switch (params.lz77_method) {
    case HistogramParams::LZ77Method::kNone:
      return;
    case HistogramParams::LZ77Method::kRLE:
      ApplyLZ77_RLE(params, num_contexts, tokens, lz77, tokens_lz77);
      return;
    case HistogramParams::LZ77Method::kLZ77:
      ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_lz77);
      return;
    case HistogramParams::LZ77Method::kOptimal:
      ApplyLZ77_Optimal(params, num_contexts, tokens, lz77, tokens_lz77);
      return;
  }
}
}  // namespace

Status EncodeHistograms(const std::vector<uint8_t>& context_map,
                        const EntropyEncodingData& codes, BitWriter* writer,
                        LayerType layer, AuxOut* aux_out) {
  return writer->WithMaxBits(
      128 + kClustersLimit * 136, layer, aux_out,
      [&]() -> Status {
        JXL_RETURN_IF_ERROR(Bundle::Write(codes.lz77, writer, layer, aux_out));
        if (codes.lz77.enabled) {
          EncodeUintConfig(codes.lz77.length_uint_config, writer,
                           /*log_alpha_size=*/8);
        }
        JXL_RETURN_IF_ERROR(EncodeContextMap(
            context_map, codes.encoding_info.size(), writer, layer, aux_out));
        writer->Write(1, TO_JXL_BOOL(codes.use_prefix_code));
        size_t log_alpha_size = 8;
        if (codes.use_prefix_code) {
          log_alpha_size = PREFIX_MAX_BITS;
        } else {
          log_alpha_size = 8;  // streaming_mode
          writer->Write(2, log_alpha_size - 5);
        }
        EncodeUintConfigs(codes.uint_config, writer, log_alpha_size);
        if (codes.use_prefix_code) {
          for (const auto& info : codes.encoding_info) {
            StoreVarLenUint16(info.size() - 1, writer);
          }
        }
        for (const auto& histo_writer : codes.encoded_histograms) {
          JXL_RETURN_IF_ERROR(writer->AppendUnaligned(histo_writer));
        }
        return true;
      },
      /*finished_histogram=*/true);
}

StatusOr<size_t> BuildAndEncodeHistograms(
    JxlMemoryManager* memory_manager, const HistogramParams& params,
    size_t num_contexts, std::vector<std::vector<Token>>& tokens,
    EntropyEncodingData* codes, std::vector<uint8_t>* context_map,
    BitWriter* writer, LayerType layer, AuxOut* aux_out) {
  size_t cost = 0;
  codes->lz77.nonserialized_distance_context = num_contexts;
  std::vector<std::vector<Token>> tokens_lz77;
  ApplyLZ77(params, num_contexts, tokens, codes->lz77, tokens_lz77);
  if (ans_fuzzer_friendly_) {
    codes->lz77.length_uint_config = HybridUintConfig(10, 0, 0);
    codes->lz77.min_symbol = 2048;
  }

  const size_t max_contexts = std::min(num_contexts, kClustersLimit);
  const auto& body = [&]() -> Status {
    if (writer) {
      JXL_RETURN_IF_ERROR(Bundle::Write(codes->lz77, writer, layer, aux_out));
    } else {
      size_t ebits, bits;
      JXL_RETURN_IF_ERROR(Bundle::CanEncode(codes->lz77, &ebits, &bits));
      cost += bits;
    }
    if (codes->lz77.enabled) {
      if (writer) {
        size_t b = writer->BitsWritten();
        EncodeUintConfig(codes->lz77.length_uint_config, writer,
                         /*log_alpha_size=*/8);
        cost += writer->BitsWritten() - b;
      } else {
        SizeWriter size_writer;
        EncodeUintConfig(codes->lz77.length_uint_config, &size_writer,
                         /*log_alpha_size=*/8);
        cost += size_writer.size;
      }
      num_contexts += 1;
      tokens = std::move(tokens_lz77);
    }
    size_t total_tokens = 0;
    // Build histograms.
    HistogramBuilder builder(num_contexts);
    HybridUintConfig uint_config;  //  Default config for clustering.
    // Unless we are using the kContextMap histogram option.
    if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
      uint_config = HybridUintConfig(2, 0, 1);
    }
    if (params.uint_method == HistogramParams::HybridUintMethod::k000) {
      uint_config = HybridUintConfig(0, 0, 0);
    }
    if (ans_fuzzer_friendly_) {
      uint_config = HybridUintConfig(10, 0, 0);
    }
    for (const auto& stream : tokens) {
      if (codes->lz77.enabled) {
        for (const auto& token : stream) {
          total_tokens++;
          uint32_t tok, nbits, bits;
          (token.is_lz77_length ? codes->lz77.length_uint_config : uint_config)
              .Encode(token.value, &tok, &nbits, &bits);
          tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
          builder.VisitSymbol(tok, token.context);
        }
      } else if (num_contexts == 1) {
        for (const auto& token : stream) {
          total_tokens++;
          uint32_t tok, nbits, bits;
          uint_config.Encode(token.value, &tok, &nbits, &bits);
          builder.VisitSymbol(tok, /*token.context=*/0);
        }
      } else {
        for (const auto& token : stream) {
          total_tokens++;
          uint32_t tok, nbits, bits;
          uint_config.Encode(token.value, &tok, &nbits, &bits);
          builder.VisitSymbol(tok, token.context);
        }
      }
    }

    if (params.add_missing_symbols) {
      for (size_t c = 0; c < num_contexts; ++c) {
        for (int symbol = 0; symbol < ANS_MAX_ALPHABET_SIZE; ++symbol) {
          builder.VisitSymbol(symbol, c);
        }
      }
    }

    if (params.initialize_global_state) {
      bool use_prefix_code =
          params.force_huffman || total_tokens < 100 ||
          params.clustering == HistogramParams::ClusteringType::kFastest ||
          ans_fuzzer_friendly_;
      if (!use_prefix_code) {
        bool all_singleton = true;
        for (size_t i = 0; i < num_contexts; i++) {
          if (builder.Histo(i).ShannonEntropy() >= 1e-5) {
            all_singleton = false;
          }
        }
        if (all_singleton) {
          use_prefix_code = true;
        }
      }
      codes->use_prefix_code = use_prefix_code;
    }

    if (params.add_fixed_histograms) {
      // TODO(szabadka) Add more fixed histograms.
      // TODO(szabadka) Reduce alphabet size by choosing a non-default
      // uint_config.
      const size_t alphabet_size = ANS_MAX_ALPHABET_SIZE;
      const size_t log_alpha_size = 8;
      JXL_ENSURE(alphabet_size == 1u << log_alpha_size);
      static_assert(ANS_MAX_ALPHABET_SIZE <= ANS_TAB_SIZE);
      std::vector<int32_t> counts =
          CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE);
      codes->encoding_info.emplace_back();
      codes->encoding_info.back().resize(alphabet_size);
      codes->encoded_histograms.emplace_back(memory_manager);
      BitWriter* histo_writer = &codes->encoded_histograms.back();
      JXL_RETURN_IF_ERROR(histo_writer->WithMaxBits(
          256 + alphabet_size * 24, LayerType::Header, nullptr,
          [&]() -> Status {
            JXL_ASSIGN_OR_RETURN(
                size_t ans_cost,
                BuildAndStoreANSEncodingData(
                    memory_manager, params.ans_histogram_strategy,
                    counts.data(), alphabet_size, log_alpha_size,
                    codes->use_prefix_code, codes->encoding_info.back().data(),
                    histo_writer));
            (void)ans_cost;
            return true;
          }));
    }

    // Encode histograms.
    JXL_ASSIGN_OR_RETURN(
        size_t entropy_bits,
        builder.BuildAndStoreEntropyCodes(memory_manager, params, tokens, codes,
                                          context_map, writer, layer, aux_out));
    cost += entropy_bits;
    return true;
  };
  if (writer) {
    JXL_RETURN_IF_ERROR(writer->WithMaxBits(
        128 + num_contexts * 40 + max_contexts * 96, layer, aux_out, body,
        /*finished_histogram=*/true));
  } else {
    JXL_RETURN_IF_ERROR(body());
  }

  if (aux_out != nullptr) {
    aux_out->layer(layer).num_clustered_histograms +=
        codes->encoding_info.size();
  }
  return cost;
}

size_t WriteTokens(const std::vector<Token>& tokens,
                   const EntropyEncodingData& codes,
                   const std::vector<uint8_t>& context_map,
                   size_t context_offset, BitWriter* writer) {
  size_t num_extra_bits = 0;
  if (codes.use_prefix_code) {
    for (const auto& token : tokens) {
      uint32_t tok, nbits, bits;
      size_t histo = context_map[context_offset + token.context];
      (token.is_lz77_length ? codes.lz77.length_uint_config
                            : codes.uint_config[histo])
          .Encode(token.value, &tok, &nbits, &bits);
      tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
      // Combine two calls to the BitWriter. Equivalent to:
      // writer->Write(codes.encoding_info[histo][tok].depth,
      //               codes.encoding_info[histo][tok].bits);
      // writer->Write(nbits, bits);
      uint64_t data = codes.encoding_info[histo][tok].bits;
      data |= static_cast<uint64_t>(bits)
              << codes.encoding_info[histo][tok].depth;
      writer->Write(codes.encoding_info[histo][tok].depth + nbits, data);
      num_extra_bits += nbits;
    }
    return num_extra_bits;
  }
  std::vector<uint64_t> out;
  std::vector<uint8_t> out_nbits;
  out.reserve(tokens.size());
  out_nbits.reserve(tokens.size());
  uint64_t allbits = 0;
  size_t numallbits = 0;
--> --------------------

--> maximum size reached

--> --------------------

Messung V0.5
C=94 H=88 G=90

¤ Dauer der Verarbeitung: 0.18 Sekunden  ¤

*© Formatika GbR, Deutschland






Wurzel

Suchen

Beweissystem der NASA

Beweissystem Isabelle

NIST Cobol Testsuite

Cephes Mathematical Library

Wiener Entwicklungsmethode

Haftungshinweis

Die Informationen auf dieser Webseite wurden nach bestem Wissen sorgfältig zusammengestellt. Es wird jedoch weder Vollständigkeit, noch Richtigkeit, noch Qualität der bereit gestellten Informationen zugesichert.

Bemerkung:

Die farbliche Syntaxdarstellung und die Messung sind noch experimentell.