LM.hh

Go to the documentation of this file.
00001 #ifndef LM_HH
00002 #define LM_HH
00003 
00004 #include "bit/Trie.hh"
00005 #include "bit/CompressedArray.hh"
00006 #include "bit/FloatArray.hh"
00007 #include "SymbolMap.hh"
00008 
00011 namespace bit {
00012 
00020   class LM {
00021   public:
00022 
00024     typedef Trie<CompressedArray> Trie;
00025 
00027     typedef Trie::Iterator Iterator;
00028 
00030     typedef SymbolMap<std::string, int> SymbolMap;
00031 
00033     LM()
00034     {
00035       reset();
00036     }
00037 
00039     void reset()
00040     {
00041       m_symbol_map = SymbolMap();
00042       m_start_symbol = -1;
00043       m_end_symbol = -1;
00044       m_trie = Trie();
00045       m_backoff_arrays.clear();
00046       m_score_arrays.clear();
00047       m_previous_ngram.clear();
00048     }
00049 
00051     unsigned int order() const
00052     {
00053       return m_score_arrays.size();
00054     }
00055 
00057     u64 size() const
00058     {
00059       u64 size = 0;
00060       assert(m_score_arrays.size() == m_backoff_arrays.size());
00061       for (size_t i = 0; i < m_score_arrays.size(); i++) {
00062         size += m_score_arrays[i].data_len();
00063         size += m_backoff_arrays[i].data_len();
00064       }
00065       return size + m_trie.size();
00066     }
00067 
00069     const FloatArray &score_array(unsigned int level) const
00070     {
00071       return m_score_arrays.at(level);
00072     }
00073 
00075     const FloatArray &backoff_array(unsigned int level) const
00076     {
00077       return m_backoff_arrays.at(level);
00078     }
00079 
00081     const CompressedArray &symbol_array(unsigned int level) const
00082     {
00083       return m_trie.symbol_array(level);
00084     }
00085 
00087     const CompressedArray &pointer_array(unsigned int level) const
00088     {
00089       return m_trie.pointer_array(level);
00090     }
00091 
00093     const CompressedArray &child_limit_array(unsigned int level) const
00094     {
00095       return m_trie.child_limit_array(level);
00096     }
00097 
00103     void
00104     read_arpa(FILE *file, const std::string &sentence_start_str = "<s>",
00105               const std::string &sentence_end_str = "</s>", 
00106               bool verbose = false);
00107 
00111     void write_arpa(FILE *file) const;
00112 
00117     void write(FILE *file) const;
00118 
00123     void read(FILE *file);
00124 
00133     void linear_quantization(unsigned int bits);
00134 
00138     void compress_trie(unsigned int level)
00139     {
00140       m_trie.compress(level);
00141     }
00142 
00144     void compress_trie()
00145     {
00146       m_trie.compress();
00147     }
00148 
00152     void uncompress_trie(unsigned int level)
00153     {
00154       m_trie.uncompress(level);
00155     }
00156 
00158     void uncompress_trie()
00159     {
00160       m_trie.uncompress();
00161     }
00162 
00171     void separate_leafs(unsigned int level);
00172 
00181     void unseparate_leafs(unsigned int level);
00182 
00191     void
00192     insert_ngram(const std::vector<int> &ngram, float score, float backoff);
00193 
00202     void
00203     insert_ngram(const std::string &str, float score, float backoff);
00204 
00209     void set_start_symbol(const std::string &str)
00210     {
00211       if (m_start_symbol >= 0)
00212         throw bit::invalid_call("bit::LM::set_start_symbol() called again");
00213       m_start_symbol = m_symbol_map.insert(str);
00214     }
00215 
00220     void set_end_symbol(const std::string &str)
00221     {
00222       if (m_end_symbol >= 0)
00223         throw bit::invalid_call("bit::LM::set_end_symbol() called again");
00224       m_end_symbol = m_symbol_map.insert(str);
00225     }
00226 
00228     int start_symbol() const { 
00229       return m_start_symbol;
00230     }
00231 
00233     int end_symbol() const { 
00234       return m_end_symbol;
00235     }
00236 
00238     const SymbolMap &symbol_map() const
00239     {
00240       return m_symbol_map;
00241     }
00242 
00247     template <class T>
00248     std::string 
00249     ngram_str(const std::vector<T> &vec) const
00250     {
00251       assert(!vec.empty());
00252       std::string str;
00253       for (size_t o = 0; o < vec.size(); o++) {
00254         if (o != 0)
00255           str.append(" ");
00256         str.append(m_symbol_map.at(vec[o]));
00257       }
00258       return str;
00259     }
00260 
00262     Iterator root() const
00263     {
00264       return Iterator(m_trie);
00265     }
00266 
00270     float backoff(const Iterator &it) const
00271     {
00272       if (it.is_root())
00273         throw invalid_call("lm::LM::backoff() called at root");
00274       u32 index = it.child_limit_index();
00275       if (index == max_u32)
00276         return 0;
00277       return backoff(it.length() - 1, index);
00278     }
00279 
00287     float backoff(unsigned int level, u64 index) const
00288     {
00289       if (level >= m_score_arrays.size())
00290         throw bit::invalid_argument("bit::LM::backoff() level too high");
00291       const FloatArray &score_array = m_score_arrays.at(level);
00292       const FloatArray &backoff_array = m_backoff_arrays.at(level);
00293       if (index >= score_array.num_elems())
00294         throw bit::invalid_argument("bit::LM::backoff() index too high");
00295       if (index >= backoff_array.num_elems())
00296         return 0;
00297       return backoff_array.get(index);
00298     }
00299 
00301     float score(const Iterator &it) const
00302     {
00303       u64 index = it.symbol_index();
00304       unsigned int level = it.length() - 1;
00305       return m_score_arrays.at(level).get(index);
00306     }
00307 
00316     float walk(Iterator &it, int symbol) const
00317     {
00318       float score = 0;
00319       while (!it.goto_child(symbol)) {
00320         assert(it.length() > 0);
00321         score += backoff(it);
00322         it.goto_backoff_full();
00323       }
00324       u64 index = it.symbol_index();
00325       score += m_score_arrays.at(it.length() - 1).get(index);
00326       if (it.num_children() == 0) {
00327         assert(backoff(it) == 0);
00328         it.goto_backoff_full();
00329       }
00330       return score;
00331     }
00332 
00333   private:
00334 
00338     int compare_ngrams(const std::vector<int> &a, const std::vector<int> &b)
00339     {
00340       std::vector<int>::size_type o = 0;
00341       while (1) {
00342         if (o == a.size() && o == b.size())
00343           return 0;
00344         if (o == a.size())
00345           return -1;
00346         if (o == b.size())
00347           return 1;
00348         if (a[o] < b[o])
00349           return -1;
00350         if (a[o] > b[o])
00351           return 1;
00352         o++;
00353       }
00354     }
00355 
00356   private:
00357 
00359     SymbolMap m_symbol_map;
00360 
00362     int m_start_symbol;
00363 
00365     int m_end_symbol;
00366 
00368     Trie m_trie;
00369 
00371     std::vector<FloatArray> m_backoff_arrays;
00372 
00374     std::vector<FloatArray> m_score_arrays;
00375 
00377     std::vector<int> m_previous_ngram;
00378   };
00379 
00380 };
00381 
00382 #endif /* LM_HH */

Generated on Mon Jan 8 15:51:03 2007 for bit by  doxygen 1.4.6