Kneser.hh

Go to the documentation of this file.
00001 #ifndef KNESER_HH
00002 #define KNESER_HH
00003 
00004 #include "bit/Array.hh"
00005 #include "bit/Trie.hh"
00006 #include "str/str.hh"
00007 #include "lm/SymbolMap.hh"
00008 #include "util/Progress.hh"
00009 #include "util/util.hh"
00010 
00011 namespace bit {
00012 
00020   class Kneser { 
00021   public:
00022   
    // NOTE(review): `typedef Trie<Array> Trie;` (and likewise the SymbolMap
    // typedef) redeclares inside the class a name that the declaration itself
    // uses; GCC may reject this as "changes meaning of 'Trie'". Qualifying the
    // template name would avoid it -- confirm its namespace before changing.
    typedef Trie<Array> Trie;                       // n-gram tree type
    typedef Trie::Iterator Iterator;                // cursor into the n-gram tree
    typedef SymbolMap<std::string, int> SymbolMap;  // word string <-> integer id
    typedef std::vector<float> FloatVec;
    typedef std::vector<int> IntVec;
    typedef std::vector<int> Ngram;                 // sequence of symbol ids
00029 
00031     Kneser()
00032     {
00033       m_d1_weight_model = 1;
00034       m_d1_model = 1;
00035 
00036       m_sentence_start_str = "<s>";
00037       m_sentence_end_str = "</s>";
00038       m_sentence_start_id = -1;
00039       m_sentence_end_id = -1;
00040       m_progress_skip = 53871;
00041       reserve_orders(10); // FIXME magic number
00042     }
00043 
00045     void set_d1_model(int model)
00046     {
00047       m_d1_model = model;
00048     }
00049 
00051     void set_d1_weight_model(int model)
00052     {
00053       m_d1_weight_model = model;
00054     }
00055 
00057     const SymbolMap &symbol_map() const
00058     {
00059       return m_symbol_map;
00060     }
00061 
00063     int sentence_start_id() const
00064     {
00065       return m_sentence_start_id;
00066     }
00067 
00069     int sentence_end_id() const
00070     {
00071       return m_sentence_end_id;
00072     }
00073 
00075     Iterator root() const
00076     {
00077       return Iterator(m_trie);
00078     }
00079 
00081     u64 num_ngrams() const
00082     {
00083       u64 ret = 0;
00084       for (size_t i = 0; i < m_counts.size(); i++)
00085         ret += m_counts[i].size();
00086       return ret;
00087     }
00088 
00090     u64 num_active_ngrams() const
00091     {
00092       u64 ret = 0;
00093       for (size_t i = 0; i < m_num_ngrams.size(); i++)
00094         ret += m_num_ngrams[i];
00095       return ret;
00096     }
00097 
00102     template <class T>
00103     Iterator find(const std::vector<T> &vec) const
00104     {
00105       return m_trie.find(vec);
00106     }
00107 
00112     Iterator find(const std::string &str) const
00113     {
00114       return m_trie.find(ngram(str));
00115     }
00116 
00123     float ngram_prob(Iterator it) const
00124     {
00125       double ret = 1;
00126       if (it.length() == 0)
00127         throw bit::invalid_argument("bit::Kneser::prob_weight() at root");
00128       while (it.length() > 0) {
00129         if (it.length() != 1 || it.symbol() != m_sentence_start_id) {
00130           switch (m_d1_model) {
00131           case 0:
00132             ret *= prob_full(it);
00133             break;
00134           case 1:
00135             ret *= prob_abs_full(it);
00136             break;
00137           case 2:
00138             throw bit::invalid_argument(
00139               "bit::Kneser::ngram_prob() invalid m_d1_model");
00140           }
00141         }
00142         it.goto_parent();
00143       }
00144       return ret;
00145     }
00146 
00153     float prob_beta_lower(Iterator it) const
00154     {
00155       assert(it.length() > 0);
00156 
00157       double ret = 0;
00158       double scale = 1;
00159       u32 symbol = it.symbol();
00160       it.goto_parent();
00161       scale *= get_beta_interpolation_numerator(it) /
00162         get_beta_denominator(it);
00163 
00164       while (!it.is_root()) {
00165         it.goto_backoff_full();
00166         if (it.goto_child(symbol)) {
00167           double numerator = 0;
00168           if (!is_pruned(it)) {
00169             numerator = get_beta_numerator(it) - 
00170               get_beta_discount(it.length() - 1);
00171             it.goto_parent();
00172             ret += scale * numerator / get_beta_denominator(it);
00173           }
00174           else
00175             it.goto_parent();
00176         }
00177 
00178         scale *= get_beta_interpolation_numerator(it) /
00179           get_beta_denominator(it);
00180       }
00181 
00182       // Zero-gram probability.  Sentence start is not counted as a
00183       // real symbol, thus "minus one"
00184       //
00185       ret += scale / (m_symbol_map.size() - 1);
00186       
00187       return ret;
00188     }
00189 
00191     Ngram ngram(const std::string &str) const
00192     {
00193       std::vector<std::string> symbols = str::split(str, " \t", true);
00194       Ngram ngram(symbols.size());
00195       for (size_t i = 0; i < symbols.size(); i++)
00196         ngram[i] = m_symbol_map.index(symbols[i]);
00197       return ngram;
00198     }
00199 
00205     float prob_beta_full(const Iterator &it) const
00206     {
00207       assert(m_sentence_start_id >= 0);
00208       assert(it.length() > 0);
00209       if (it.symbol() == m_sentence_start_id)
00210         throw bit::invalid_argument(
00211           "bit::Kneser::prob_beta_full(): called for sentence start");
00212 
00213       unsigned int len = it.length();
00214       double numerator = 0;
00215       if (!is_pruned(it))
00216         numerator = get_beta_numerator(it) - 
00217           get_beta_discount(len - 1);
00218       double denominator = get_beta_denominator(it.parent());
00219       return numerator / denominator + prob_beta_lower(it);
00220     }
00221 
00230     float prob_beta_full(Ngram ngram) const
00231     {
00232       if (ngram.empty())
00233         throw bit::invalid_argument(
00234           "bit::Kneser::prob_beta_full(Ngram): ngram empty");
00235       int symbol = ngram.back();
00236       if (symbol == m_sentence_start_id)
00237         throw bit::invalid_argument(
00238           "bit::Kneser::prob_beta_full(Ngram) sentence start symbol");
00239         
00240       double scale = 1;
00241       double ret = 0;
00242       while (1) {
00243         assert(ngram.size() > 0);
00244         Iterator it = find(ngram);
00245 
00246         // Ngram found?
00247         if (!it.is_root()) {
00248           Iterator parent = it.parent();
00249           double denominator = get_beta_denominator(parent);
00250           if (!is_pruned(it)) {
00251             double numerator = get_beta_numerator(it) - 
00252               get_beta_discount(it.length() - 1);
00253             assert(numerator >= 0);
00254             ret += scale * numerator / denominator;
00255           }
00256           scale *= get_beta_interpolation_numerator(parent) / denominator;
00257         }
00258 
00259         // Ngram not found?
00260         else {
00261           if (it.length() == 1)
00262             scale *= get_beta_interpolation_numerator(root()) /
00263               get_beta_denominator(root());
00264           else {
00265             Ngram parent_ngram(ngram.begin(), ngram.end() - 1);
00266             Iterator parent_it = find(parent_ngram);
00267             if (!parent_it.is_root())
00268               scale *= get_beta_interpolation_numerator(parent_it) /
00269                 get_beta_denominator(parent_it);
00270           }
00271         }
00272 
00273         if (it.length() == 1)
00274           break;
00275 
00276         ngram.erase(ngram.begin());
00277       }
00278 
00279       ret += scale / (m_symbol_map.size() - 1);
00280 
00281       return ret;
00282     }
00283 
00289     float prob_lower(Iterator it) const
00290     {
00291       assert(it.length() > 0);
00292 
00293       double ret = 0;
00294       double scale = 1;
00295       u32 symbol = it.symbol();
00296       it.goto_parent();
00297       scale *= interpolation(it);
00298 
00299       while (!it.is_root()) {
00300         it.goto_backoff_full();
00301         if (it.goto_child(symbol)) {
00302           double numerator = 
00303             sum_nonzero_xg(it) - get_discount(it.length() - 1);
00304           it.goto_parent();
00305           double denominator = sum_nonzero_xgx(it);
00306           ret += scale * numerator / denominator;
00307         }
00308 
00309         scale *= interpolation(it);
00310       }
00311 
00312       // Zero-gram probability.  Sentence start is not counted as a
00313       // real symbol, thus "minus one"
00314       //
00315       ret += scale / (m_symbol_map.size() - 1);
00316       
00317       return ret;
00318     }
00319 
00326     float prob_full(const Iterator &it, float *lower_prob = NULL) const
00327     {
00328       assert(m_sentence_start_id >= 0);
00329       assert(it.length() > 0);
00330       if (it.symbol() == m_sentence_start_id)
00331         throw bit::invalid_argument(
00332           "bit::Kneser::prob_full_ikn(): called for sentence start");
00333 
00334       unsigned int len = it.length();
00335       double numerator = sum_nonzero_xg(it) - get_discount(len - 1);
00336       double denominator = sum_nonzero_xgx(it.parent());
00337       double lower = prob_lower(it);
00338       if (lower_prob != NULL)
00339         *lower_prob = lower;
00340       return numerator / denominator + lower;
00341     }
00342 
00349     float prob_abs_lower(Iterator it) const
00350     {
00351       assert(it.length() > 0);
00352 
00353       double ret = 0;
00354       double scale = 1;
00355       u32 symbol = it.symbol();
00356       it.goto_parent();
00357       scale *= interpolation_abs(it);
00358 
00359       while (!it.is_root()) {
00360         it.goto_backoff_full();
00361         if (it.goto_child(symbol)) {
00362           double numerator = 
00363             get_count(it) - get_discount(it.length() - 1);
00364           it.goto_parent();
00365           double denominator = sum_gx(it);
00366           ret += scale * numerator / denominator;
00367         }
00368 
00369         scale *= interpolation_abs(it);
00370       }
00371 
00372       // Zero-gram probability.  Sentence start is not counted as a
00373       // real symbol, thus "minus one"
00374       //
00375       ret += scale / (m_symbol_map.size() - 1);
00376       
00377       return ret;
00378     }
00379 
00386     float prob_abs_full(const Iterator &it, float *lower_prob = NULL) const
00387     {
00388       assert(m_sentence_start_id >= 0);
00389       assert(it.length() > 0);
00390       if (it.symbol() == m_sentence_start_id)
00391         throw bit::invalid_argument(
00392           "bit::Kneser::prob_abs_full(): called for sentence start");
00393 
00394       unsigned int len = it.length();
00395       double numerator = get_count(it) - get_discount(len - 1);
00396       double denominator = sum_gx(it.parent());
00397       double lower = prob_abs_lower(it);
00398       if (lower_prob != NULL)
00399         *lower_prob = lower;
00400       return numerator / denominator + lower;
00401     }
00402 
00404     bool is_pruned(const Iterator &it) const
00405     {
00406       return get_value(m_pruned, it) > 0;
00407     }
00408 
00415     u32 get_count(const Iterator &it) const
00416     {
00417       return get_value(m_counts, it);
00418     }
00419 
00420     u32 sum_gx(const Iterator &it) const 
00421     {
00422       if (it.is_root())
00423         return m_sum_gx0;
00424       return get_value(m_sum_gx, it);
00425     }
00426 
00427     u32 sum_nonzero_xg(const Iterator &it) const
00428     {
00429       return get_value(m_sum_nonzero_xg, it);
00430     }
00431 
00432     u32 sum_nonzero_gx(const Iterator &it) const
00433     {
00434       if (it.is_root())
00435         return m_sum_nonzero_gx0;
00436       return get_value(m_sum_nonzero_gx, it);
00437 
00438 //       u32 ret = 0;
00439 //       if (it.goto_first_child()) {
00440 //         do {
00441 //           if (get_count(it) > 0)
00442 //             ret++;
00443 //         } while (it.goto_next_sibling());
00444 //       }
00445 //       return ret;
00446     }
00447 
00448     u32 sum_nonzero_xgx(const Iterator &it) const
00449     {
00450       if (it.is_root())
00451         return m_sum_nonzero_xgx0;
00452       return get_value(m_sum_nonzero_xgx, it);
00453 
00454 //       u32 ret = 0;
00455 //       if (it.goto_first_child()) {
00456 //         do {
00457 //           ret += sum_nonzero_xg(it);
00458 //         } while (it.goto_next_sibling());
00459 //       }
00460 //       return ret;
00461     }
00462 
00463     float get_beta_numerator(const Iterator &it) const
00464     {
00465       if (is_pruned(it))
00466         return 0;
00467 
00468       int term1 = get_count(it) - get_value(m_sum_xg_not_pruned, it);
00469 
00470       float term2;
00471       if (1) {
00472         term2 = get_beta_discount(it.length()) * 
00473           get_value(m_sum_nonzero_xg_not_pruned, it);
00474       }
00475       else {
00476         term2 = get_value(m_sum_nonzero_xg_not_pruned, it);
00477         static bool printed = false;
00478         if (!printed) {
00479           fprintf(stderr, "\n"
00480                   "**************************************************\n"
00481                   "* WARNING: using d = 1 in beta numerator\n"
00482                   "*\n");
00483         }
00484         printed = true;
00485       }
00486       
00487       if (term1 < 0)
00488         throw std::out_of_range(
00489           str::fmt(256, "N(h,w) - Sum_v N(v,h,w) negative: %d\n", term1));
00490 
00491       return term1 + term2;
00492     }
00493 
00495     float get_beta_denominator(const Iterator &it) const
00496     {
00497       if (it.is_root())
00498         return m_beta_denominator0;
00499       if (it.length() > m_beta_denominator.size())
00500         return 1;
00501       return get_value(m_beta_denominator, it);
00502     }
00503 
00504     float get_beta_interpolation_numerator(const Iterator &it) const
00505     {
00506       if (it.is_root())
00507         return m_beta_interpolation_numerator0;
00508       if (it.length() > m_beta_interpolation_numerator.size())
00509         return 1;
00510       return get_value(m_beta_interpolation_numerator, it);
00511     }
00512 
00516     float get_d1(const Iterator &it) const
00517     {
00518       if (it.length() < 2)
00519         throw bit::invalid_argument("bit::Kneser::get_d1(): too low order");
00520       return get_value(m_d1, it);
00521     }
00522 
00526     float get_d2(const Iterator &it) const
00527     {
00528       if (it.length() < 2)
00529         throw bit::invalid_argument("bit::Kneser::get_d2(): too low order");
00530       return get_value(m_d2, it) / get_value(m_d2_norm, it);
00531     }
00532 
00534     int num_active_children(Iterator it) const
00535     {
00536       int ret = 0;
00537       if (it.goto_first_child()) {
00538         do {
00539           if (!is_pruned(it) && it.symbol() != m_sentence_start_id)
00540             ret++;
00541         } while (it.goto_next_sibling());
00542       }
00543       return ret;
00544     }
00545 
00550     template <class T>
00551     std::string 
00552     ngram_str(const std::vector<T> &ngram) const
00553     {
00554       std::string str;
00555       for (size_t o = 0; o < ngram.size(); o++) {
00556         if (o != 0)
00557           str.append(" ");
00558         str.append(m_symbol_map.at(ngram[o]));
00559       }
00560       return str;
00561     }
00562 
00565     void write_binary_counts(FILE *file) const
00566     {
00567       m_symbol_map.write(file);
00568       m_trie.write(file);
00569       util::write(file, m_counts);
00570       util::write(file, m_sum_nonzero_xg);
00571       util::write(file, m_sum_nonzero_xgx);
00572       util::write(file, m_sum_nonzero_gx);
00573       util::write(file, m_sum_gx);
00574       fwrite(&m_sum_nonzero_xgx0, sizeof(m_sum_nonzero_xgx0), 1, file);
00575       fwrite(&m_sum_nonzero_gx0, sizeof(m_sum_nonzero_gx0), 1, file);
00576       fwrite(&m_sum_gx0, sizeof(m_sum_gx0), 1, file);
00577     }
00578 
00581     void read_binary_counts(FILE *file) 
00582     {
00583       m_symbol_map.read(file);
00584       m_trie.read(file);
00585       util::read(file, m_counts);
00586       util::read(file, m_sum_nonzero_xg);
00587       util::read(file, m_sum_nonzero_xgx);
00588       util::read(file, m_sum_nonzero_gx);
00589       util::read(file, m_sum_gx);
00590       size_t ret = 0;
00591       ret += fread(&m_sum_nonzero_xgx0, sizeof(m_sum_nonzero_xgx0), 1, file);
00592       ret += fread(&m_sum_nonzero_gx0, sizeof(m_sum_nonzero_gx0), 1, file);
00593       ret += fread(&m_sum_gx0, sizeof(m_sum_gx0), 1, file);
00594       if (ret != 3)
00595         throw io_error("bit::Kneser::read_binary_counts() failed");
00596 
00597       m_sentence_start_id = m_symbol_map.index(m_sentence_start_str);
00598       m_sentence_end_id = m_symbol_map.index(m_sentence_end_str);
00599       m_num_ngrams.resize(m_counts.size());
00600       for (size_t i = 0; i < m_counts.size(); i++)
00601         m_num_ngrams.at(i) = m_counts[i].size();
00602     }
00603 
00605     void write_binary_d1d2(FILE *file) const
00606     {
00607       util::write(file, m_d1);
00608       util::write(file, m_d2);
00609       util::write(file, m_d2_norm);
00610     }
00611 
00613     void read_binary_d1d2(FILE *file) 
00614     {
00615       util::read(file, m_d1);
00616       util::read(file, m_d2);
00617       util::read(file, m_d2_norm);
00618     }
00619 
00621     void write_arpa(FILE *file) const 
00622     {
00623       fprintf(file, "\\data\\\n");
00624 //      for (size_t l = 0; l < order(); l++)
00625 //        fprintf(file, "ngram %d=%lld\n", (int)l+1, 
00626 //                m_score_arrays.at(l).num_elems());
00627 
00628       for (size_t o = 0; o < m_counts.size(); o++) {
00629         fprintf(file, "\n\\%zd-grams:\n", o+1);
00630         Iterator it = root();
00631         while (it.goto_next_on_level(o)) {
00632 
00633           float log_prob;
00634           if (it.length() == 1 && it.symbol() == m_sentence_start_id)
00635             log_prob = -99;
00636           else
00637             log_prob = log10(prob_full(it));
00638 
00639           fprintf(file, "%g\t%s", log_prob,
00640                   ngram_str(it.symbol_vec()).c_str());
00641           if (it.num_children() > 0)
00642             fprintf(file, "\t%g\n", log10(interpolation(it)));
00643           else
00644             fputs("\n", file);
00645         }
00646       }
00647       fprintf(file, "\n\\end\\\n");
00648     }
00649 
00651     void write_beta_arpa(FILE *file) const 
00652     {
00653       fprintf(file, "\\data\\\n");
00654 //      for (size_t l = 0; l < order(); l++)
00655 //        fprintf(file, "ngram %d=%lld\n", (int)l+1, 
00656 //                m_score_arrays.at(l).num_elems());
00657 
00658       for (size_t i = 0; i < m_num_ngrams.size(); i++) {
00659         assert(m_num_ngrams[i] >= 0);
00660         if (m_num_ngrams[i] == 0)
00661           break;
00662         fprintf(file, "ngram %d=%d\n", (int)i + 1, m_num_ngrams[i]);
00663       }
00664 
00665       Progress p(m_progress_skip, num_ngrams());
00666       p.set_report_string("writing arpa:");
00667 
00668       for (size_t o = 0; o < m_counts.size(); o++) {
00669         fprintf(file, "\n\\%zd-grams:\n", o+1);
00670         Iterator it = root();
00671         while (it.goto_next_on_level(o)) {
00672           p.step();
00673           if (get_value(m_pruned, it) > 0)
00674             continue;
00675 
00676           float log_prob;
00677           if (it.length() == 1 && it.symbol() == m_sentence_start_id)
00678             log_prob = -99;
00679           else
00680             log_prob = log10(prob_beta_full(it));
00681 
00682           fprintf(file, "%g\t%s", log_prob,
00683                   ngram_str(it.symbol_vec()).c_str());
00684           if (num_active_children(it) > 0)
00685             fprintf(file, "\t%g\n", 
00686                     log10(get_beta_interpolation_numerator(it) /
00687                           get_beta_denominator(it)));
00688           else
00689             fputs("\n", file);
00690         }
00691       }
00692       fprintf(file, "\n\\end\\\n");
00693       p.finish();
00694     }
00695 
00699     void reserve_orders(unsigned int orders)
00700     {
00701       m_counts.reserve(orders);
00702       m_sum_nonzero_xg.reserve(orders);
00703       m_sum_nonzero_gx.reserve(orders);
00704       m_sum_nonzero_xgx.reserve(orders);
00705       m_pruned.reserve(orders);
00706       m_d1.reserve(orders);
00707       m_d2.reserve(orders);
00708       m_d2_norm.reserve(orders);
00709       m_sum_xg_not_pruned.reserve(orders);
00710       m_sum_nonzero_xg_not_pruned.reserve(orders);
00711       m_beta_denominator.reserve(orders);
00712       m_beta_interpolation_numerator.reserve(orders);
00713       m_trie.reserve_levels(orders);
00714     }
00715     
00716 
00725     void read_counts(FILE *file, bool integer_symbols = false)
00726     {
00727       std::string line;
00728       std::vector<int> ngram;
00729       std::vector<std::string> fields;
00730       int count;
00731       Progress p(m_progress_skip);
00732       p.set_report_string("reading counts:");
00733       while (str::read_line(line, file, true)) {
00734         try {
00735           if (integer_symbols) {
00736             ngram = str::long_vec<int>(line);
00737             if (ngram.size() < 2)
00738               throw std::exception();
00739             count = ngram.back();
00740             ngram.pop_back();
00741           }
00742           else {
00743             fields = str::split(line, " \t", true);
00744             if (fields.size() < 2)
00745               throw std::exception();
00746             count = str::str2long(fields.back());
00747             fields.pop_back();
00748             ngram.resize(fields.size());
00749             for (size_t i = 0; i < fields.size(); i++)
00750               ngram[i] = m_symbol_map.insert(fields[i]);
00751           }
00752         }
00753         catch (std::exception &e) {
00754           throw bit::io_error(
00755             std::string("bit::Kneser::read_counts(): invalid line: ") + line);
00756         }
00757         
00758         add(ngram, count);
00759         p.step();
00760       }
00761       p.finish();
00762 
00763       m_num_ngrams.resize(m_counts.size());
00764       for (size_t i = 0; i < m_counts.size(); i++)
00765         m_num_ngrams.at(i) = m_counts[i].size();
00766     }
00767 
00773     void compute_sums()
00774     {
00775       assert(m_sentence_start_id < 0);
00776       m_sentence_start_id = m_symbol_map.index(m_sentence_start_str);
00777       m_sentence_end_id = m_symbol_map.index(m_sentence_end_str);
00778       {
00779         Iterator it = root();
00780         it.goto_child(m_sentence_start_id);
00781         set_count(it, 0);
00782         fprintf(stderr, "WARNING: setting count(%s) = 0\n", 
00783                 m_sentence_start_str.c_str());
00784       }
00785 
00786       assert(m_sum_nonzero_xg.empty());
00787       assert(m_sum_nonzero_xgx.empty());
00788       m_sum_gx.resize(m_counts.size() - 1);
00789       m_sum_nonzero_xg.resize(m_counts.size());
00790       m_sum_nonzero_xgx.resize(m_counts.size() - 1);
00791       m_sum_nonzero_gx.resize(m_counts.size() - 1);
00792       for (size_t i = 0; i < m_sum_nonzero_xg.size(); i++)
00793         m_sum_nonzero_xg[i].resize(m_counts[i].size());
00794       for (size_t i = 0; i < m_sum_nonzero_xgx.size(); i++) {
00795         m_sum_gx[i].resize(m_counts[i].size());
00796         m_sum_nonzero_xgx[i].resize(m_counts[i].size());
00797         m_sum_nonzero_gx[i].resize(m_counts[i].size());
00798       }
00799       m_sum_nonzero_xgx0 = 0;
00800       m_sum_nonzero_gx0 = 0;
00801       m_sum_gx0 = 0;
00802 
00803       // Compute modified counts
00804       //
00805       {
00806         Progress p(m_progress_skip, num_ngrams());
00807         p.set_report_string("modified counts:");
00808         Iterator it = root();
00809         while (it.goto_next_depth_first()) {
00810           p.step();
00811 
00812           // sum_gx is needed only if absolute discounting is used
00813           if (m_d1_model == 1 || m_d1_weight_model == 1) {
00814             Iterator parent(it.parent());
00815             u32 count = get_count(it);
00816             if (parent.length() == 0)
00817               m_sum_gx0 += count;
00818             else
00819               add_value(m_sum_gx, parent, count);
00820           }
00821 
00822           if (it.length() < 2)
00823             continue;
00824         
00825           Iterator bo_it(it);
00826           if (bo_it.goto_backoff_once())
00827             add_value(m_sum_nonzero_xg, bo_it, 1);
00828         }
00829         p.finish();
00830       }
00831 
00832       // Use original counts for ngrams that do not have contexts
00833       // (ngrams starting with <s> for example)
00834       //
00835       {
00836         Progress p(m_progress_skip, num_ngrams());
00837         p.set_report_string("without context:");
00838         for (size_t o = 0; o < m_counts.size(); o++) {
00839           IntVec &src_array = m_counts.at(o);
00840           IntVec &sum_nonzero_xg = m_sum_nonzero_xg.at(o);
00841           for (u64 i = 0; i < src_array.size(); i++) {
00842             p.step();
00843             if (sum_nonzero_xg.at(i) == 0)
00844               sum_nonzero_xg[i] = src_array.at(i);
00845           }
00846         }
00847         p.finish();
00848       }
00849 
00850       // Compute sum_nonzero_xgx and sum_nonzero_gx
00851       //
00852       {
00853         Progress p(m_progress_skip, num_ngrams());
00854         p.set_report_string("modified sum counts:");
00855         Iterator it = root();
00856         while (it.goto_next_depth_first()) {
00857           p.step();
00858           if (it.symbol() == m_sentence_start_id)
00859             continue;
00860           u32 value = sum_nonzero_xg(it);
00861           Iterator parent = it.parent();
00862           if (parent.is_root()) {
00863             m_sum_nonzero_xgx0 += value;
00864             m_sum_nonzero_gx0++;
00865           }
00866           else {
00867             add_value(m_sum_nonzero_xgx, parent, value);
00868             add_value(m_sum_nonzero_gx, parent, 1);
00869           }
00870         }
00871         p.finish();
00872       }
00873     }
00874 
00876     void compute_d1()
00877     {
00878       float max_d1 = -1e30;
00879       float min_d1 = 1e30;
00880 
00881       assert(m_d1.empty());
00882       m_d1.resize(m_counts.size());
00883       for (size_t i = 0; i < m_counts.size(); i++)
00884         m_d1.at(i).resize(m_counts.at(i).size());
00885 
00886       Iterator it = root();
00887 
00888       Progress p(m_progress_skip, num_ngrams());
00889       p.set_report_string("computing d1:");
00890       while (it.goto_next_depth_first()) {
00891         p.step();
00892         if (it.length() < 2)
00893           continue;
00894 
00895         double orig = 0;
00896         float lower_prob = 0;
00897         if (m_d1_model == 0) {
00898           orig = prob_full(it, &lower_prob);
00899         }
00900         else if (m_d1_model == 1) {
00901           orig = prob_abs_full(it, &lower_prob);
00902         }
00903         else
00904           throw bit::invalid_argument(
00905             "bit::Kneser::compute_d1() invalid m_d1_model");
00906 
00907         double d1 = ngram_prob(it) * log10(orig / lower_prob);
00908 
00909         if (0) {
00910           fprintf(stderr, "d1: %12g %s\n", d1, 
00911                   ngram_str(it.symbol_vec()).c_str());
00912         }
00913 
00914         if (!(d1 > 0 && d1 < 1e10)) {
00915           fprintf(stderr, "WARNING: d1 = %g for %s\n", d1, 
00916                   ngram_str(it.symbol_vec()).c_str());
00917         }
00918 
00919         set_value(m_d1, it, d1);
00920         if (d1 < min_d1)
00921           min_d1 = d1;
00922         if (d1 > max_d1)
00923           max_d1 = d1;
00924       }
00925       p.finish();
00926 
00927       fprintf(stderr, "min_d1 = %g\n", min_d1);
00928       fprintf(stderr, "max_d1 = %g\n", max_d1);
00929     }
00930 
00933     void compute_d2_full()
00934     {
00935       assert(m_d2.empty());
00936       assert(!m_d1.empty());
00937       assert(m_d2_norm.empty());
00938       m_d2.resize(m_counts.size());
00939       m_d2_norm.resize(m_counts.size());
00940       for (size_t i = 0; i < m_counts.size(); i++) {
00941         m_d2.at(i).resize(m_counts.at(i).size());
00942         m_d2_norm.at(i).resize(m_counts.at(i).size());
00943       }
00944 
00945       // Compute d2 measures bottom-up fashion
00946       //
00947       Iterator it = root();
00948       Progress p(m_progress_skip, num_ngrams());
00949       p.set_report_string("full d2:");
00950       while (it.goto_next_depth_first_post()) {
00951         p.step();
00952         if (it.length() < 2)
00953           continue;
00954 
00955         D2Norm pair(get_d1(it), 1);
00956         Iterator child_it(it);
00957         if (child_it.goto_first_child()) {
00958           do {
00959             float d2 = get_value(m_d2, child_it);
00960             u32 norm = get_value(m_d2_norm, child_it);
00961             pair.add(d2, norm);
00962           } while (child_it.goto_next_sibling());
00963         }
00964 
00965         set_value(m_d2, it, pair.d2);
00966         set_value(m_d2_norm, it, pair.norm);
00967 
00968         if (0) {
00969           fprintf(stderr, "d2: %12g %s\n", pair.d2 / pair.norm, 
00970                   ngram_str(it.symbol_vec()).c_str());
00971         }
00972       }
00973       p.finish();
00974     }
00975 
00980     void compute_d2_trick()
00981     {
00982       fprintf(stderr, "WARNING: using erroneous d2 measure\n");
00983 
00984       assert(m_d2.empty());
00985       assert(!m_d1.empty());
00986       assert(m_d2_norm.empty());
00987       m_d2.resize(m_counts.size());
00988       m_d2_norm.resize(m_counts.size());
00989       for (size_t i = 0; i < m_counts.size(); i++)
00990         m_d2.at(i).resize(m_counts.at(i).size());
00991 
00992       // Compute d2 measures bottom-up fashion
00993       //
00994       std::vector<D2Norm> child_pairs;
00995       Iterator it = root();
00996       Progress p(m_progress_skip, num_ngrams());
00997       p.set_report_string("computing d2:");
00998       while (it.goto_next_depth_first_post()) {
00999         p.step();
01000         if (it.length() < 2)
01001           continue;
01002 
01003         D2Norm pair(get_d1(it), 1);
01004         Iterator child_it(it);
01005         child_pairs.clear();
01006         if (child_it.goto_first_child()) {
01007           do {
01008             float d2 = get_value(m_d2, child_it);
01009             u32 norm = get_value(m_d2_norm, child_it);
01010             pair.add(d2, norm);
01011             child_pairs.push_back(D2Norm(d2, norm));
01012           } while (child_it.goto_next_sibling());
01013           std::sort(child_pairs.begin(), child_pairs.end());
01014         }
01015 
01016         // Virtually remove childrens until the parent's d2 is smaller
01017         // than that of best children.
01018         //
01019         for (size_t i = 0; i < child_pairs.size(); i++) {
01020           if (pair < child_pairs[i])
01021             break;
01022           pair.add(-child_pairs[i].d2, -child_pairs[i].norm);
01023         }
01024 
01025         set_value(m_d2, it, pair.d2);
01026         set_value(m_d2_norm, it, pair.norm);
01027       }
01028       p.finish();
01029     }
01030 
    /**
     * Prunes the n-gram at \a it.  It marks the n-gram and every
     * not-yet-pruned descendant as pruned and decrements the active n-gram
     * counter of each affected order.  It then subtracts the n-gram's d2 sum
     * and normalizer from every ancestor of order >= 2.
     *
     * \throw bit::invalid_argument if \a it is shorter than a 2-gram or is
     *   already pruned.
     */
    void prune_ngram(Iterator it)
    {
      if (it.length() < 2)
        throw bit::invalid_argument(
          "bit::Kneser::prune_ngram() ngram shorter than 2-gram");
        
      if (get_value(m_pruned, it) > 0)
        throw bit::invalid_argument(
          "bit::Kneser::prune_ngram() ngram pruned already");
      set_value(m_pruned, it, 1);
      m_num_ngrams.at(it.length() - 1)--;
      
      // Read before the ancestors are updated below.  Per compute_d2_full(),
      // m_d2 of an n-gram holds the total of its whole subtree.
      float d2 = get_value(m_d2, it);
      int d2_norm = get_value(m_d2_norm, it);

      // Debug output, disabled.
      if (0) {
        fprintf(stderr, "pruned %12g %12g / %d %s\n", 
                get_value(m_d1, it), 
                d2, d2_norm, 
                ngram_str(it.symbol_vec()).c_str());
      }

      // Mark children pruned too
      {
        Iterator child_it(it);
        unsigned int len = it.length();
        // Depth-first traversal from `it`; the first node whose length is not
        // greater than `it`'s lies outside its subtree, so stop there.
        while (child_it.goto_next_depth_first()) {
          if (child_it.length() <= len)
            break;
          // Already-pruned descendants were already removed from the counters.
          if (get_value(m_pruned, child_it) > 0)
            continue;
          if (0) {
            fprintf(stderr, "pruned %12g %12g / %d %s\n", 
                    get_value(m_d1, child_it),
                    get_value(m_d2, child_it),
                    get_value(m_d2_norm, child_it),
                    ngram_str(child_it.symbol_vec()).c_str());
          }
          set_value(m_pruned, child_it, 1);
          m_num_ngrams.at(child_it.length() - 1)--;
        }
      }

      // Modify parents' d2 measure.  Unigrams hold no d2 values, so the walk
      // stops when the parent is a unigram.
      while (1) {
        it.goto_parent();
        if (it.length() == 1)
          break;
        add_value(m_d2, it, -d2);
        add_value(m_d2_norm, it, -d2_norm);
      }
    }
01089 
01093     void prune_threshold(float threshold)
01094     {
01095       assert(!m_d1.empty());
01096       assert(!m_d2.empty());
01097       assert(m_pruned.empty());
01098       m_pruned.resize(m_counts.size());
01099 
01100       Progress p(m_progress_skip, num_ngrams());
01101       p.set_report_string("pruning:");
01102       Iterator it = root();
01103       while (it.goto_next_depth_first_post()) {
01104         p.step();
01105         if (it.length() < 2)
01106           continue;
01107         if (get_d2(it) < threshold)
01108           prune_ngram(it);
01109       }
01110       p.finish();
01111       fprintf(stderr, "%lld ngrams left\n", num_active_ngrams());
01112     }
01113 
01122     void prune(unsigned int ngrams)
01123     {
01124       assert(!m_d1.empty());
01125       assert(!m_d2.empty());
01126       assert(m_pruned.empty());
01127       m_pruned.resize(m_counts.size());
01128 
01129       size_t num_ngrams = 0;
01130       for (size_t i = 0; i < m_counts.size(); i++)
01131         num_ngrams += m_counts.at(i).size();
01132       std::vector<OrderIndex> vec;
01133       vec.reserve(num_ngrams);
01134 
01135       Iterator it = root();
01136       while (it.goto_next_depth_first_post()) {
01137         assert(it.length() > 0);
01138         if (it.length() < 2)
01139           continue;
01140         vec.push_back(OrderIndex(it.length() - 1, it.symbol_index()));
01141       }
01142 
01143       if (ngrams > vec.size())
01144         throw bit::invalid_argument(
01145           "bit::Kneser::prune() trying to prune too many ngrams");
01146 
01147       Progress p(0, 2);
01148       p.set_report_string("sorting:");
01149       p.step();
01150       std::partial_sort(vec.begin(), vec.begin() + ngrams, vec.end(), 
01151                         PruneCompare(this));
01152       p.step();
01153       p.finish();
01154 
01155       for (size_t i = 0; i < ngrams; i++) {
01156         m_num_ngrams.at(vec[i].order)--;
01157         m_pruned.at(vec[i].order).set_grow_widen(vec[i].index, 1);
01158       }
01159 
01160 //       fprintf(stderr, "DEBUG: pruned\n");
01161 //       it = root();
01162 //       while (it.goto_next_depth_first())
01163 //         if (get_value(m_pruned, it) > 0)
01164 //           fprintf(stderr, "  %10g %s\n", get_d2(it), 
01165 //                   ngram_str(it.symbol_vec()).c_str());
01166 
01167       // Prune children of the pruned ngrams if they are not pruned already
01168       //
01169       it = root();
01170       while (it.goto_next_depth_first()) {
01171         if (it.length() < 2)
01172           continue;
01173         if (get_value(m_pruned, it.parent()) > 0 &&
01174             get_value(m_pruned, it) == 0)
01175         {
01176           unsigned int order = it.length() - 1;
01177           m_num_ngrams.at(order)--;
01178           set_value(m_pruned, it, 1);
01179 //           fprintf(stderr, 
01180 //                   "WARNING: pruning \"%s\" because parent was pruned\n",
01181 //                   ngram_str(it.symbol_vec()).c_str());
01182         }
01183       }
01184 
01185     }
01186 
01190     void compute_beta_numerator_terms()
01191     {
01192       assert(!m_discounts.empty());
01193       m_beta_discounts = m_discounts;
01194       for (size_t i = m_beta_discounts.size() - 1; i > 0; i--)
01195         m_beta_discounts[i - 1] *= m_beta_discounts[i];
01196 
01197       assert(m_sentence_start_id >= 0);
01198       assert(m_sum_xg_not_pruned.empty());
01199       assert(m_sum_nonzero_xg_not_pruned.empty());
01200       m_sum_xg_not_pruned.resize(m_counts.size() - 1);
01201       m_sum_nonzero_xg_not_pruned.resize(m_counts.size() - 1);
01202       for (size_t i = 0; i < m_counts.size() - 1; i++) {
01203         m_sum_xg_not_pruned.at(i).resize(m_counts.at(i).size());
01204         m_sum_nonzero_xg_not_pruned.at(i).resize(m_counts.at(i).size());
01205       }
01206 
01207       Iterator it = root();
01208       Progress p(m_progress_skip, num_ngrams());
01209       p.set_report_string("beta numerator terms:");
01210       while (it.goto_next_depth_first()) {
01211         p.step();
01212         if (get_value(m_pruned, it) > 0)
01213           continue;
01214         if (it.length() < 2)
01215           continue;
01216 
01217         Iterator bo_it(it);
01218         if (bo_it.goto_backoff_once()) {
01219           add_value(m_sum_xg_not_pruned, bo_it, get_count(it));
01220           add_value(m_sum_nonzero_xg_not_pruned, bo_it, 1);
01221         }
01222       }
01223       p.finish();
01224     }
01225 
    /** Fill m_beta_denominator and m_beta_denominator0.
     *
     * m_beta_denominator holds entries for n-gram lengths
     * 1 .. m_counts.size() - 1; m_beta_denominator0 holds the root's value.
     *
     * Each n-gram's entry is first set to the summed counts of its pruned
     * children, as reported by compute_active_children(), which ignores
     * sentence-start children. Then get_beta_numerator() of every child
     * whose symbol is not the sentence start is added to it. Numerators of
     * unigrams are accumulated into m_beta_denominator0.
     *
     * NOTE: correctness depends on goto_next_depth_first() visiting a node
     * before its children (pre-order). Otherwise the set_value() below would
     * overwrite numerators the children have already added.
     */
    void compute_beta_denominator()
    {
      assert(m_beta_denominator.empty());
      m_beta_denominator.resize(m_counts.size() - 1);
      m_beta_denominator0 = 0;
      for (size_t i = 0; i < m_beta_denominator.size(); i++)
        m_beta_denominator.at(i).resize(m_counts.at(i).size());

      Iterator it = root();
      Progress p(m_progress_skip, num_ngrams());
      p.set_report_string("beta denominator:");
      while (it.goto_next_depth_first()) {
        p.step();

        // N-grams of the highest order (length m_counts.size()) have no
        // entry of their own in m_beta_denominator.
        if (it.length() <= m_beta_denominator.size()) {
          int pruned_counts;
          compute_active_children(it, &pruned_counts);
          set_value(m_beta_denominator, it, pruned_counts);
        }

        // Sentence-start n-grams contribute no numerator to their parent.
        if (it.symbol() == m_sentence_start_id)
          continue;

        Iterator parent_it(it);
        parent_it.goto_parent();

//        fprintf(stderr, "%10g %s\n", get_beta_numerator(it), 
//                ngram_str(it.symbol_vec()).c_str());

        // Unigrams (whose parent is the root) add to the separate root
        // scalar; all other n-grams add to their parent's entry.
        float numerator = get_beta_numerator(it);
        if (parent_it.is_root())
          m_beta_denominator0 += numerator;
        else
          add_value(m_beta_denominator, parent_it, numerator);
      }
      p.finish();
    }
01266 
01268     int compute_active_children(Iterator it, int *pruned_counts = NULL)
01269     {
01270       int num_active_children = 0;
01271       if (pruned_counts != NULL)
01272         (*pruned_counts) = 0;
01273       if (it.goto_first_child()) {
01274         do {
01275           if (it.symbol() == m_sentence_start_id)
01276             continue;
01277           if (is_pruned(it)) {
01278             if (pruned_counts != NULL)
01279               (*pruned_counts) += get_count(it);
01280             continue;
01281           }
01282               
01283           num_active_children++;
01284         } while (it.goto_next_sibling());
01285       }
01286       return num_active_children;
01287     }
01288 
    /** Fill m_beta_interpolation_numerator and
     * m_beta_interpolation_numerator0.
     *
     * m_beta_interpolation_numerator holds entries for n-gram lengths
     * 1 .. m_counts.size() - 1; m_beta_interpolation_numerator0 holds the
     * root's value.
     *
     * A pruned n-gram gets the value 1. Otherwise the value is
     *   get_beta_discount(length) * (number of active children)
     *     + (summed counts of pruned children),
     * with both child quantities taken from compute_active_children(), so
     * sentence-start children are ignored.
     */
    void compute_beta_interpolation_numerator()
    {
      assert(m_beta_interpolation_numerator.empty());
      m_beta_interpolation_numerator.resize(m_counts.size() - 1);
      for (size_t i = 0; i < m_beta_interpolation_numerator.size(); i++)
        m_beta_interpolation_numerator.at(i).resize(m_counts.at(i).size());

      m_beta_interpolation_numerator0 = 0;
      Progress p(m_progress_skip, num_ngrams());
      p.set_report_string("beta interpolation numerator:");
      Iterator it = root();
      // do/while so that the root itself is processed before advancing.
      // `continue` below jumps to the loop condition, i.e. advances the
      // iterator.
      do {
        p.step();

        // Highest-order n-grams have no entry.
        if (it.length() > m_beta_interpolation_numerator.size())
          continue;

        if (!it.is_root()) {
          if (is_pruned(it)) {
            set_value(m_beta_interpolation_numerator, it, 1);
            continue;
          }
        }

        int pruned_counts = 0;
        int num_active_children = compute_active_children(it, &pruned_counts);
        double interpolation = get_beta_discount(it.length()) *
          num_active_children + pruned_counts;

        if (it.is_root())
          m_beta_interpolation_numerator0 = interpolation;
        else
          set_value(m_beta_interpolation_numerator, it, interpolation);
      } while (it.goto_next_depth_first());
      p.finish();
    }
01329 
01336     Iterator add(const std::vector<int> &vec, int value)
01337     {
01338       Iterator it = m_trie.insert(vec);
01339       add_value(m_counts, it, value);
01340       return it;
01341     }
01342 
01350     void set_discount(unsigned int order, float value)
01351     {
01352       m_discounts.resize(order + 1);
01353       m_discounts.at(order) = value;
01354     }
01355 
01361     float get_discount(unsigned int order) const
01362     {
01363       if (m_discounts.empty())
01364         throw bit::invalid_call(
01365           "bit::Kneser::get_discount() discount not set");
01366       if (order >= m_discounts.size())
01367         return m_discounts.back();
01368       return m_discounts.at(order);
01369     }
01370 
01376     float get_beta_discount(unsigned int order) const
01377     {
01378       if (m_beta_discounts.empty())
01379         throw bit::invalid_call(
01380           "bit::Kneser::get_beta_discount() discount not set");
01381       if (order >= m_beta_discounts.size())
01382         return m_beta_discounts.back();
01383       return m_beta_discounts.at(order);
01384     }
01385 
01386     float interpolation(const Iterator &it) const
01387     {
01388       assert(it.length() == 0 || it.symbol() != m_sentence_end_id);
01389       double nominator = sum_nonzero_gx(it) * get_discount(it.length());
01390       double denominator = sum_nonzero_xgx(it);
01391       return nominator / denominator;
01392     }
01393 
01394     float interpolation_abs(const Iterator &it) const
01395     {
01396       assert(it.length() == 0 || it.symbol() != m_sentence_end_id);
01397       double nominator = sum_nonzero_gx(it) * get_discount(it.length());
01398       double denominator = sum_gx(it);
01399       return nominator / denominator;
01400     }
01401 
01406     void set_count(const Iterator &it, u32 value)
01407     {
01408       set_value(m_counts, it, value);
01409     }
01410 
01411     std::string debug_sum_nonzero_xg_str()
01412     {
01413       std::string str;
01414       Iterator it = root();
01415       while (it.goto_next_depth_first()) {
01416         str.append(ngram_str(it.symbol_vec()));
01417         str.append("\t");
01418         str.append(str::fmt(64, "%d\n", sum_nonzero_xg(it)));
01419       }
01420       return str;
01421     }
01422 
01423     void debug_write_counts(FILE *file)
01424     {
01425       fprintf(file, "sum_nonzero_xg_not_pruned:\n");
01426       Iterator it = root();
01427       while (it.goto_next_depth_first()) {
01428         fprintf(file, "%s\t  bdenom=%.2f  bnum=%.2f\n", 
01429                 ngram_str(it.symbol_vec()).c_str(),
01430                 it.length() < 3 ? get_beta_denominator(it) : -1,
01431                 get_beta_numerator(it)
01432           );
01433                 
01434       }
01435     }
01436 
01437 
01438   private:
01439 
01440     float get_value(const std::vector<FloatVec> &arrays, const Iterator &it) const
01441     {
01442       unsigned int len = it.length();
01443       u32 index = it.symbol_index();
01444       return arrays.at(len - 1).at(index);
01445     }
01446 
01447     int get_value(const std::vector<IntVec> &arrays, const Iterator &it) const
01448     {
01449       unsigned int len = it.length();
01450       if (len > arrays.size())
01451         return 0;
01452       u32 index = it.symbol_index();
01453       const IntVec &vec = arrays.at(len - 1);
01454       if (index >= vec.size())
01455         return 0;
01456       return vec[index];
01457     }
01458 
01459     u32 get_value(const std::vector<Array> &arrays, const Iterator &it) const
01460     {
01461       unsigned int len = it.length();
01462       u32 index = it.symbol_index();
01463       if (len > arrays.size())
01464         return 0;
01465       const Array &array = arrays.at(len - 1);
01466       if (index >= array.num_elems())
01467         return 0;
01468       return array.get(index);
01469     }
01470 
01471     void set_value(std::vector<FloatVec> &arrays, const Iterator &it, 
01472                    float value)
01473     {
01474       unsigned int len = it.length();
01475       u32 index = it.symbol_index();
01476       arrays.at(len - 1).at(index) = value;
01477     }
01478 
01479     void set_value(std::vector<IntVec> &arrays, const Iterator &it, 
01480                    int value)
01481     {
01482       unsigned int len = it.length();
01483       if (len > arrays.size())
01484         arrays.resize(len);
01485       u32 index = it.symbol_index();
01486       IntVec &vec = arrays.at(len - 1);
01487       while (vec.size() <= index)
01488         vec.push_back(0);
01489       vec[index] = value;
01490     }
01491 
01492     void set_value(std::vector<Array> &arrays, const Iterator &it, u32 value)
01493     {
01494       unsigned int len = it.length();
01495       u32 index = it.symbol_index();
01496       if (len > arrays.size())
01497         arrays.resize(len);
01498       arrays.at(len - 1).set_grow_widen(index, value);
01499     }
01500 
01501     void add_value(std::vector<Array> &arrays, const Iterator &it, u32 value)
01502     {
01503       set_value(arrays, it, get_value(arrays, it) + value);
01504     }
01505 
01506     void add_value(std::vector<FloatVec> &arrays, const Iterator &it, 
01507                    float value)
01508     {
01509       set_value(arrays, it, get_value(arrays, it) + value);
01510     }
01511 
01512     void add_value(std::vector<IntVec> &arrays, const Iterator &it, 
01513                    int value)
01514     {
01515       set_value(arrays, it, get_value(arrays, it) + value);
01516     }
01517 
01518     void sub_value(std::vector<Array> &arrays, const Iterator &it, u32 value)
01519     {
01520       u32 old_value = get_value(arrays, it);
01521       if (old_value < value)
01522         throw bit::invalid_argument("bit::Kneser::sub_value() underflow");
01523       set_value(arrays, it, get_value(arrays, it) - value);
01524     }
01525 
01528 
    /// Selects the d1 weighting model. Set by set_d1_weight_model();
    /// initialized to 1 in the constructor. The meaning of the values is
    /// defined where it is read (not in this part of the file).
    int m_d1_weight_model;

    /// Selects the d1 model. Set by set_d1_model(); initialized to 1 in the
    /// constructor.
    int m_d1_model;

    /// First argument given to the Progress reporters of the long trie
    /// traversals; initialized to 53871 in the constructor.
    /// NOTE(review): presumably the number of steps between progress
    /// reports -- confirm in util/Progress.hh.
    int m_progress_skip;

    /// Trie of all n-grams; add() inserts into it and root() iterates it.
    Trie m_trie;

    /// Number of active (unpruned) n-grams per order, indexed by n-gram
    /// length minus one. Decremented whenever an n-gram is pruned.
    std::vector<int> m_num_ngrams;

    /// N-gram counts, indexed [length - 1][symbol index]. Updated by add()
    /// and set_count().
    std::vector<IntVec> m_counts;

    /// Per-n-gram sums; presumably back sum_gx() (defined elsewhere).
    std::vector<IntVec> m_sum_gx;

    /// Root-level counterpart of m_sum_gx (presumed).
    int m_sum_gx0;

    /// Per-n-gram sums; presumably back sum_nonzero_xg() (defined
    /// elsewhere).
    std::vector<IntVec> m_sum_nonzero_xg;

    /// Per-n-gram sums; presumably back sum_nonzero_gx(), which
    /// interpolation() uses.
    std::vector<IntVec> m_sum_nonzero_gx;

    /// Root-level counterpart of m_sum_nonzero_gx (presumed).
    int m_sum_nonzero_gx0;

    /// Per-n-gram sums; presumably back sum_nonzero_xgx(), the denominator
    /// of interpolation().
    std::vector<IntVec> m_sum_nonzero_xgx;

    /// Root-level counterpart of m_sum_nonzero_xgx (presumed).
    int m_sum_nonzero_xgx0;

    /// Discount table. Written by set_discount(); get_discount() returns the
    /// last entry for indices past the end.
    std::vector<float> m_discounts;

    /// Suffix products of m_discounts: entry i is
    /// m_discounts[i] * ... * m_discounts.back(). Built by
    /// compute_beta_numerator_terms().
    std::vector<float> m_beta_discounts;

    /// Mapping between symbol strings and integer ids; exposed through
    /// symbol_map().
    SymbolMap m_symbol_map;

    /// Sentence-start symbol string ("<s>" by default).
    std::string m_sentence_start_str;

    /// Sentence-end symbol string ("</s>" by default).
    std::string m_sentence_end_str;

    /// Id of the sentence-start symbol; -1 until assigned.
    int m_sentence_start_id;

    /// Id of the sentence-end symbol; -1 until assigned.
    int m_sentence_end_id;

    /// Pruning flags, indexed [length - 1][symbol index]. Non-zero marks a
    /// pruned n-gram. Sized by prune() / prune_threshold().
    std::vector<Array> m_pruned; 
    /// Per-n-gram d1 measure, indexed [length - 1][symbol index]. Must be
    /// computed before pruning.
    std::vector<FloatVec> m_d1;
    /// Per-n-gram d2 measure. prune_ngram() subtracts a pruned n-gram's d2
    /// from its ancestors.
    std::vector<FloatVec> m_d2;
    /// Normalizer of m_d2; PruneCompare ranks n-grams by m_d2 / m_d2_norm.
    std::vector<IntVec> m_d2_norm;
    /// For each n-gram: summed counts of the unpruned n-grams that back off
    /// to it (goto_backoff_once()). Built by compute_beta_numerator_terms().
    std::vector<IntVec> m_sum_xg_not_pruned;
    /// For each n-gram: number of unpruned n-grams that back off to it.
    std::vector<IntVec> m_sum_nonzero_xg_not_pruned;

    /// Beta denominators for n-gram lengths 1 .. m_counts.size() - 1.
    /// Filled by compute_beta_denominator().
    std::vector<FloatVec> m_beta_denominator;

    /// Beta denominator of the root: the sum of the unigram beta numerators.
    float m_beta_denominator0;

    /// Beta interpolation numerators for n-gram lengths
    /// 1 .. m_counts.size() - 1. Filled by
    /// compute_beta_interpolation_numerator().
    std::vector<FloatVec> m_beta_interpolation_numerator;

    /// Beta interpolation numerator of the root.
    float m_beta_interpolation_numerator0;
01620 
01621     struct OrderIndex {
01622       OrderIndex(unsigned int order = 0, u32 index = 0) 
01623         : order(order), index(index) { }
01624       unsigned int order;
01625       u32 index;
01626     };
01627 
01628     struct PruneCompare {
01629       const Kneser *k;
01630 
01631       PruneCompare(const Kneser *k)
01632         : k(k)
01633       {
01634       }
01635 
01636       bool operator()(const OrderIndex &a, const OrderIndex &b) const
01637       {
01638         float d2_a = k->m_d2.at(a.order).at(a.index);
01639         int d2_a_norm = k->m_d2_norm.at(a.order).at(a.index);
01640         float d2_b = k->m_d2.at(b.order).at(b.index);
01641         int d2_b_norm = k->m_d2_norm.at(b.order).at(b.index);
01642         return (d2_a / d2_a_norm) < (d2_b / d2_b_norm);
01643       }
01644     };
01645 
01646     struct D2Norm {
01647       float d2;
01648       int norm;
01649       float value;
01650 
01651       D2Norm() : d2(0), norm(0), value(0)
01652       { 
01653       }
01654 
01655       D2Norm(float d2, int norm) : d2(d2), norm(norm), value(d2 / norm)
01656       { 
01657       }
01658 
01659       void add(float d2, int norm)
01660       {
01661         this->d2 += d2;
01662         this->norm += norm;
01663         value = this->d2 / this->norm;
01664       }
01665 
01666       bool operator<(const D2Norm &p) const
01667       {
01668         return value < p.value;
01669       }
01670     };
01671 
01673   };
01674 
01675 }
01676 
01677 #endif /* KNESER_HH */

Generated on Mon Jan 8 15:51:03 2007 for bit by  doxygen 1.4.6