Jafar
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
jann.hpp
Go to the documentation of this file.
00001 
00006 #ifndef JMATH_JANN_HPP
00007 #define JMATH_JANN_HPP
00008 
00009 #include "jafarConfig.h"
00010 
00011 #ifdef HAVE_FLANN
00012 
00013 #include "jmath/jmathException.hpp"
00014 
00015 #include <flann/flann.hpp>
00016 #include <boost/numeric/ublas/vector.hpp>
00017 #include <boost/numeric/ublas/matrix.hpp>
00018 
00021 
00023 namespace ublas = boost::numeric::ublas;
00024 
00026 namespace jann {
00027   // /// algorithm to be used for the search; use only if you want to create a new Index
00028   // enum algorithm {
00029   //  LINEAR        = 0,
00030   //  KDTREE        = 1,
00031   //  KMEANS        = 2,
00032   //  COMPOSITE     = 3,
00033   //  KDTREE_SINGLE = 4,
00034   //  SAVED         = 254,
00035   //  AUTOTUNED     = 255
00036   // };
00037   // /// algorithm to initialize centers for the K Means algorithm
00038   // enum centers_init {
00039   //  CENTERS_RANDOM   = 0,
00040   //  CENTERS_GONZALES = 1,
00041   //  CENTERS_KMEANSPP = 2
00042   // };
00043   // /// determine log level
00044   // enum log_level {
00045   //  LOG_NONE  = 0,
00046   //  LOG_FATAL = 1,
00047   //  LOG_ERROR = 2,
00048   //  LOG_WARN  = 3,
00049   //  LOG_INFO  = 4
00050   // };
00051   // /// supported distances
00052   // enum distance {
00053   //  EUCLIDEAN        = 1,
00054   //  MANHATTAN        = 2,
00055   //  MINKOWSKI        = 3,
00056   //  MAX_DIST         = 4,
00057   //  HIST_INTERSECT   = 5,
00058   //  HELLINGER        = 6,
00059   //  CS               = 7,
00060   //  CHI_SQUARE       = 7,
00061   //  KL               = 8,
00062   //  KULLBACK_LEIBLER = 8
00063   // };
00064 
00067   class search_params : public flann::SearchParams {
00068   public:
00074     search_params(int checks = 32, float eps = 0, bool sorted = true ) :
00075       flann::SearchParams(checks, eps, sorted){}    
00076   };
00077 
00083   template<typename D>
00084   class index_factory 
00085   {
00086     flann::Index<D> *m_index;
00088     template<typename T>
00089     inline void convert(const flann::Matrix<T>& flann_mat, ublas::matrix<T>& ublas_mat){
00090       JFR_ASSERT(((flann_mat.rows == ublas_mat.size1()) && 
00091                   (flann_mat.cols == ublas_mat.size2())),
00092                  "ublas matrix and flann matrix need to have the same sizes");
00093       for(size_t r = 0; r < flann_mat.rows; r++)
00094         for(size_t c = 0; c < flann_mat.cols; c++)
00095           ublas_mat(r,c) = flann_mat[r][c];
00096     }
00098     template<typename T>
00099     inline void convert(const ublas::matrix<T>& ublas_mat, flann::Matrix<T>& flann_mat){
00100       JFR_ASSERT(((flann_mat.rows == ublas_mat.size1()) && 
00101                   (flann_mat.cols == ublas_mat.size2())),
00102                  "ublas matrix and flann matrix need to have the same sizes");
00103       for(size_t r = 0; r < flann_mat.rows; r++)
00104         for(size_t c = 0; c < flann_mat.cols; c++)
00105           flann_mat[r][c] = ublas_mat(r,c);
00106     }
00110     template<typename T>
00111     inline void convert(const flann::Matrix<T>& flann_mat, ublas::vector<T>& ublas_vec){
00112       JFR_ASSERT(((flann_mat.rows == 1) && (flann_mat.cols >= ublas_vec.size())),
00113                  "ublas vector and flann matrix rows need to have the same sizes");
00114       if(flann_mat.cols > ublas_vec.size())
00115         ublas_vec.resize(flann_mat.cols);
00116       for(size_t counter = 0; counter < ublas_vec.size(); counter++)
00117           ublas_vec[counter] = flann_mat[0][counter];
00118     }
00120     template<typename T>
00121     inline void convert(const ublas::vector<T>& ublas_vec, flann::Matrix<T>& flann_mat){
00122       JFR_ASSERT(((flann_mat.rows == 1) && (flann_mat.cols >= ublas_vec.size())),
00123                  "ublas vector and flann matrix rows need to have the same sizes");
00124         for(size_t counter = 0; counter < ublas_vec.size(); counter++)
00125           flann_mat[0][counter] = ublas_vec[counter];
00126     }
00128     template<typename T>
00129     inline void convert(const std::vector<T>& std_vec, flann::Matrix<T>& flann_mat){
00130       JFR_ASSERT(((flann_mat.rows == 1) && (flann_mat.cols >= std_vec.size())),
00131                  "std vector and flann matrix rows need to have the same sizes");
00132         for(size_t counter = 0; counter < std_vec.size(); counter++)
00133           flann_mat[0][counter] = std_vec[counter];
00134     }
00135   public:
00137     typedef typename D::ElementType element;
00139     typedef typename D::ResultType result;
00141     flann::Matrix<element> dataset;
00143     index_factory() {}
00149     void operator()(const ublas::matrix<typename D::ElementType>& _data, 
00150                     const flann::IndexParams& params, D d = D() )
00151     {
00152       dataset = flann::Matrix<element>(new element[_data.size1()*_data.size2()], 
00153                                        _data.size1(), _data.size2());
00154       convert(_data, dataset);
00155       m_index = new flann::Index<D>(dataset, params, d);
00156     }
00162     index_factory(const ublas::matrix<typename D::ElementType>& _data, 
00163                   const flann::IndexParams& params, D d = D() )
00164     {
00165       dataset = flann::Matrix<element>(new element[_data.size1()*_data.size2()], 
00166                                        _data.size1(), _data.size2());
00167       convert(_data, dataset);
00168       m_index = new flann::Index<D>(dataset, params, d);
00169     }
00171     virtual ~index_factory() {
00172       dataset.free();
00173     }
00175     void build() {
00176       m_index->buildIndex();
00177     }
00179     void knn_search(const ublas::matrix<typename D::ElementType>& queries, 
00180                     ublas::matrix<int>& indices, 
00181                     ublas::matrix<result>& dists, int knn, 
00182                     const search_params& params) 
00183     {
00184       size_t rows = queries.size1();
00185       size_t cols = queries.size2();
00186       JFR_PRED_ERROR(cols == dataset.cols,
00187                      jafar::jmath::JmathException,
00188                      jafar::jmath::JmathException::WRONG_SIZE,
00189                      "queries and dataset need to have same columns size")
00190       JFR_PRED_ERROR(((indices.size2() >= (size_t)knn) && 
00191                       (dists.size2() >= (size_t)knn)),
00192                      jafar::jmath::JmathException,
00193                      jafar::jmath::JmathException::WRONG_SIZE,
00194                      "indices and dists must have at least "<<knn<<" columns")
00195       JFR_PRED_ERROR(((indices.size1() == rows) && (dists.size1() == rows)),
00196                      jafar::jmath::JmathException,
00197                      jafar::jmath::JmathException::WRONG_SIZE,
00198                      "queries, indices and dists need to be of same row size")
00199       flann::Matrix<element> _queries(new element[rows*cols], rows, cols);
00200       convert(queries, _queries);
00201       flann::Matrix<int> _indices(new int[rows*indices.size2()], rows, indices.size2());
00202       flann::Matrix<result> _dists(new result[rows*dists.size2()], rows, dists.size2());
00203       m_index->knnSearch(_queries, _indices, _dists, knn, params);
00204       convert(_indices, indices);
00205       convert(_dists, dists);
00206 
00207       _queries.free();
00208       _dists.free();
00209       _indices.free();
00210     }
00212     void knn_search(const ublas::vector<element>& query, 
00213                     ublas::vector<int>& indices, ublas::vector<result>& dists, 
00214                     int knn, const search_params& params) 
00215     {
00216       size_t length = query.size();
00217       JFR_PRED_ERROR(length == dataset.cols,
00218                      jafar::jmath::JmathException,
00219                      jafar::jmath::JmathException::WRONG_SIZE,
00220                      "query size must be of dataset columns size")
00221         JFR_PRED_ERROR(((indices.size() >= size_t(knn)) && (dists.size() >= size_t(knn))),
00222                      jafar::jmath::JmathException,
00223                      jafar::jmath::JmathException::WRONG_SIZE,
00224                      "indices and dists must be at least of size "<<knn)
00225       flann::Matrix<element> _query(new element[length], 1, length);
00226       convert(query,_query);
00227       flann::Matrix<int> _indices(new int[indices.size()], 1, indices.size());
00228       flann::Matrix<result> _dists(new result[dists.size()], 1, dists.size());
00229       m_index->knnSearch(_query, _indices, _dists, knn, params);
00230       convert(_indices, indices);
00231       convert(_dists, dists);
00232 
00233       _query.free();
00234       _dists.free();
00235       _indices.free();
00236     }
00238     void knn_search(const std::vector<element>& query, 
00239                     std::vector<int>& indices, std::vector<result>& dists, 
00240                     int knn, const search_params& params) 
00241     {
00242       size_t length = query.size();
00243       JFR_PRED_ERROR(length == dataset.cols,
00244                      jafar::jmath::JmathException,
00245                      jafar::jmath::JmathException::WRONG_SIZE,
00246                      "query size must be of dataset columns size")
00247         JFR_PRED_ERROR(((indices.size() >= size_t(knn)) && (dists.size() >= size_t(knn))),
00248                      jafar::jmath::JmathException,
00249                      jafar::jmath::JmathException::WRONG_SIZE,
00250                      "indices and dists must be at least of size "<<knn)
00251       flann::Matrix<element> _query(new element[length], 1, length);
00252       convert(query,_query);
00253       flann::Matrix<int> _indices(new int[indices.size()], 1, indices.size());
00254       flann::Matrix<result> _dists(new result[dists.size()], 1, dists.size());
00255       m_index->knnSearch(_query, _indices, _dists, knn, params);
00256       convert(_indices, indices);
00257       convert(_dists, dists);
00258 
00259       _query.free();
00260       _dists.free();
00261       _indices.free();
00262     }
00266     int radius_search(const ublas::vector<element>& query, 
00267                       ublas::matrix<int>& indices, 
00268                       ublas::matrix<result>& dists, float radius, 
00269                       const search_params& params) 
00270     {
00271       size_t length = query.size();
00272       JFR_PRED_ERROR(length == dataset.cols,
00273                      jafar::jmath::JmathException,
00274                      jafar::jmath::JmathException::WRONG_SIZE,
00275                      "query length and dataset columns must be equal")
00276       JFR_PRED_ERROR((indices.size2() == dists.size2()),
00277                      jafar::jmath::JmathException,
00278                      jafar::jmath::JmathException::WRONG_SIZE,
00279                      "indices and dists must have same columns number")
00280       flann::Matrix<element> _query(new element[length], 1, length);
00281       convert(query, _query);
00282       flann::Matrix<int> _indices(new int[indices.size1()*indices.size2()], indices.size1(), indices.size2());
00283       flann::Matrix<result> _dists(new result[dists.size1()*dists.size2()], dists.size1(), dists.size2());
00284       m_index->radiusSearch(_query, _indices, _dists, radius, params);
00285       convert(_indices, indices);
00286       convert(_dists, dists);
00287 
00288       _query.free();
00289       _dists.free();
00290       _indices.free();
00291     }
00295     int radius_search(const ublas::vector<element>& query,
00296                       ublas::vector<int>& indices, ublas::vector<result>& dists, 
00297                       float radius, const search_params& params) 
00298     {
00299       size_t length = query.size();
00300       JFR_PRED_ERROR(length == dataset.cols,
00301                      jafar::jmath::JmathException,
00302                      jafar::jmath::JmathException::WRONG_SIZE,
00303                      "query size must be of dataset columns size")
00304         JFR_PRED_ERROR((indices.size() == dists.size()),
00305                      jafar::jmath::JmathException,
00306                      jafar::jmath::JmathException::WRONG_SIZE,
00307                      "indices and dists must have same size")
00308       flann::Matrix<element> _query(new element[length], 1, length);
00309       convert(query,_query);
00310       flann::Matrix<int> _indices(new int[indices.size()], 1, indices.size());
00311       flann::Matrix<result> _dists(new result[dists.size()], 1, dists.size());
00312       int result = m_index->radiusSearch(_query, _indices, _dists, radius, params);
00313       convert(_indices, indices);
00314       convert(_dists, dists);
00315 
00316       _query.free();
00317       _dists.free();
00318       _indices.free();
00319       return result;
00320     }
00324     int radius_search(const std::vector<element>& query,
00325                       std::vector<int>& indices, std::vector<result>& dists, 
00326                       float radius, const search_params& params) 
00327     {
00328       size_t length = query.size();
00329       JFR_PRED_ERROR(length == dataset.cols,
00330                      jafar::jmath::JmathException,
00331                      jafar::jmath::JmathException::WRONG_SIZE,
00332                      "query size must be of dataset columns size")
00333         JFR_PRED_ERROR((indices.size() == dists.size()),
00334                      jafar::jmath::JmathException,
00335                      jafar::jmath::JmathException::WRONG_SIZE,
00336                      "indices and dists must have same size")
00337       flann::Matrix<element> _query(new element[length], 1, length);
00338       convert(query,_query);
00339       flann::Matrix<int> _indices(new int[indices.size()], 1, indices.size());
00340       flann::Matrix<result> _dists(new result[dists.size()], 1, dists.size());
00341       int result = m_index->radiusSearch(_query, _indices, _dists, radius, params);
00342       convert(_indices, indices);
00343       convert(_dists, dists);
00344 
00345       _query.free();
00346       _dists.free();
00347       _indices.free();
00348       return result;
00349     }
00350 
00351   public:
00353     void save(std::string filename) const
00354     {
00355       m_index->save(filename);
00356     }
00358     int data_size() const 
00359     {
00360       return m_index->veclen();
00361     }
00363     int size() const
00364     {
00365       return m_index->size();
00366     }
00368     flann::NNIndex<result>* index() 
00369     { 
00370       return m_index->nnIndex; 
00371     }
00373     const flann::IndexParams* parameters() { 
00374       return m_index->nnIndex->getParameters(); 
00375     }
00376   };
00377     
00378   // template<typename D>
00379   // class Index {
00380   // protected :
00381   //  SearchParams params;
00382   //  index_factory<D> m_index;
00383   // public:
00384   //  Index(const ublas::matrix<typename D::ElementType>& dataset, IndexParams*)
00385   //  void knn_search(const ublas::vector<element>& query, int knn,
00386   //                  ublas::vector<int>& indices, ublas::vector<result>& dists) 
00387   //  {
00388   //    m_index.knn_search(query, knn, indices, dists, params);
00389   //  }
00390   //  void knn_search(const ublas::matrix<typename D::ElementType>& query, int knn,
00391   //                 ublas::matrix<int>& indices, ublas::matrix<result>& dists) 
00392   //  {
00393   //    m_index.knn_search(query, knn, indices, dists, params);
00394   //  }
00395   //  int radius_search(const ublas::vector<element>& query, float radius,
00396   //                 ublas::vector<int>& indices, ublas::vector<result>& dists)
00397   //  {
00398   //    m_index.radius_search(query, radius, indices, dists, params);
00399   //  }
00400   //  int radius_search(const ublas::matrix<typename D::ElementType>& query, float radius,
00401   //                 ublas::matrix<int>& indices, ublas::matrix<result>& dists) 
00402   //  {
00403   //    m_index.radius_search(query, radius, indices, dists, params);
00404   //  }
00405   //  virtual ~Index() {}
00406   // }
00407     
00411   template<typename DISTANCE>
00412   class linear_index : public index_factory<DISTANCE> {
00413   public:
00414     linear_index(const ublas::matrix<typename DISTANCE::ElementType>& dataset) : 
00415       index_factory<DISTANCE>(dataset, flann::LinearIndexParams()) {}
00416   };
00417     
00421   template<typename DISTANCE>
00422   class KD_tree_index : public index_factory<DISTANCE> {
00423   public:
00425     KD_tree_index(const ublas::matrix<typename DISTANCE::ElementType>& dataset, 
00426                   int nb_trees = 4) : 
00427       index_factory<DISTANCE>(dataset, flann::KDTreeIndexParams(nb_trees)) {}
00428     KD_tree_index() {}
00429     void operator() (const ublas::matrix<typename DISTANCE::ElementType>& dataset, 
00430                      int nb_trees = 4) {
00431       index_factory<DISTANCE>::operator()(dataset, flann::KDTreeIndexParams(nb_trees));
00432     }
00433   };
00434     
00438   template<typename DISTANCE>
00439   class K_means_index : public index_factory<DISTANCE> {
00440   public:
00447     K_means_index(const ublas::matrix<typename DISTANCE::ElementType>& dataset,
00448                   int branching = 32, int iterations = 11, 
00449                   flann::flann_centers_init_t init = flann::CENTERS_RANDOM, 
00450                   float cb_index = 0.2 ) :
00451       index_factory<DISTANCE>(dataset, 
00452                               flann::KMeansIndexParams(branching, iterations, 
00453                                                        init, cb_index)) {}
00454   };
00455 
00459   template<typename DISTANCE>
00460   class composite_index : public index_factory<DISTANCE> {
00468     composite_index(const ublas::matrix<typename DISTANCE::ElementType>& dataset,
00469                     int trees = 4, int branching = 32, int iterations = 11,
00470                     flann::flann_centers_init_t init = flann::CENTERS_RANDOM, float cb_index = 0.2 ) :
00471       index_factory<DISTANCE>(dataset, 
00472                               flann::CompositeIndexParams(trees, branching,
00473                                                           iterations, init,
00474                                                           cb_index)) {}
00475   };
00476 
00480   template<typename DISTANCE>
00481   class autotuned_index : public index_factory<DISTANCE> {
00482   public:
00489     autotuned_index(const ublas::matrix<typename DISTANCE::ElementType>& dataset, 
00490                     float target_precision = 0.9, float build_weight = 0.01,
00491                     float memory_weight = 0, float sample_fraction = 0.1) :   
00492     
00493       index_factory<DISTANCE>(dataset,
00494                               flann::AutotunedIndexParams(target_precision, build_weight,
00495                                                           memory_weight, sample_fraction)) {}
00496   };
00497   
00498   // template<typename DISTANCE>
00499   // class saved_index : public index_factory<DISTANCE> {
00500   //  /**
00501   //   * @param filename: file where the index was stored
00502   //   */
00503   //   saved_index(const std::string& filename) 
00504 
00505   // };
00506 
00507 } // namespace jann
00508 
00510 /* End of Doxygen group */
00511 
00512 #endif // HAVE_FLANN
00513 #endif // JMATH_JANN_HPP
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

Generated on Wed Oct 15 2014 00:37:21 for Jafar by doxygen 1.7.6.1
LAAS-CNRS