Jafar
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
mean_shift_clustering.hpp
00001 #ifndef _MEAN_SHIFT_CLUSTERING_
00002 #define _MEAN_SHIFT_CLUSTERING_
00003 
00004 #include "jafarConfig.h"
00005 #ifdef HAVE_FLANN
00006 
00007 #include "jmath/jann.hpp"
00008 #include "jmath/jblas.hpp"
00009 #include <set>
00010 namespace jafar {
00011   namespace jmath {
00021     template<typename NUMTYPE, size_t D>
00022     class mean_shift_clustering {
00024     public:
00025       typedef ublas::bounded_vector<NUMTYPE, D> vector_type;
00026       struct cluster {
00028         vector_type center;
00030         std::vector<int> inliers;
00031         cluster() : inliers(0) {}
00033         cluster(const vector_type& _center) : center(_center), inliers(0) {}
00034         cluster(const vector_type& _center, const std::vector<int> &_inliers) : 
00035           center(_center), inliers(_inliers) {}
00036         cluster(const vector_type& _center, int* _inliers, size_t nb_inliers) :
00037           center(_center) 
00038         {
00039           this->add(_inliers, nb_inliers);
00040         }
00042         void add(size_t pt) {
00043           inliers.push_back(pt);
00044         }
00045         template <class InputIterator>
00046         void add(InputIterator first, InputIterator last) {
00047           inliers.insert(inliers.end(), first, last);
00048         }
00049         void add(int* _inliers, size_t nb_inliers) {
00050           for(size_t i = 0; i < nb_inliers; i++, _inliers++)
00051             inliers.push_back(*_inliers);
00052         }
00053       };
00054     private:
00056       void normal_center(const std::vector<vector_type> &points, 
00057                          const vector_type& mean, 
00058                          vector_type& center)
00059       {
00060         JFR_DEBUG("nb of points "<<points.size());
00061           center = ublas::zero_vector<NUMTYPE>(D);
00062         for(typename std::vector<vector_type>::const_iterator pt = points.begin();
00063             pt != points.end();
00064             ++pt) {
00065           vector_type difference = *pt - mean;
00066           NUMTYPE n = exp(-sum_2<NUMTYPE>(difference)/(2.0 * gaussian_variance));
00067           // NUMTYPE n = exp(-sum_2(difference)/(2.0 * gaussian_variance));
00068           //    NUMTYPE n = exp(-distance_2(*pt,mean)/(2.0 * gaussian_variance));   
00069           difference*= n;
00070           center+= difference;
00071         }
00072         center/=NUMTYPE(points.size());
00073         center+=mean;
00074       }
00076       void uniform_center(const std::vector<vector_type> &points, 
00077                           vector_type& center)
00078       {
00079         center = ublas::zero_vector<NUMTYPE>(D);
00080         for(typename std::vector<vector_type>::const_iterator pt = points.begin();
00081             pt != points.end();
00082             ++pt) {
00083           center+= *pt;
00084         }
00085         center/=NUMTYPE(points.size());
00086       }
00087 
00088     public:
00090       enum KERNEL_TYPE {NORMAL, UNIFORM};
00092       mean_shift_clustering(NUMTYPE _radius = 1.0,
00093                             NUMTYPE _threshold = 0.1,
00094                             NUMTYPE _min_distance = 0.1,
00095                             unsigned int _max = 100,
00096                             mean_shift_clustering::KERNEL_TYPE _kernel = mean_shift_clustering::NORMAL,
00097                             NUMTYPE _variance = 1.0) :
00098         window_radius(_radius), convergence_threshold(_threshold),
00099         min_distance_between_clusters(_min_distance), max_iterations(_max),
00100         kernel(_kernel), gaussian_variance(_variance) {}
00104       size_t run(const ublas::matrix<NUMTYPE>& data) 
00105       {
00106         using namespace jafar::jmath;
00107         using namespace std;
00108  
00109         JFR_ASSERT(data.size1() > 1, "need at least two points");
00110           JFR_ASSERT(data.size2() == D, "points must be of dimension D");
00111           jann::KD_tree_index< flann::L2<NUMTYPE> > points_tree(data, 4);
00112         points_tree.build();
00113         vector_type current_center;
00114         vector_type new_center;
00115         for(size_t i = 0; i < data.size1(); i++) {
00116           current_center = row(data, i);
00117           unsigned int iter = 0;
00118           vector_type this_difference;
00119           do {
00120             ublas::vector<int> this_indices;
00121             ublas::vector<NUMTYPE> this_distances;
00122             int nb_neighbours = points_tree.radius_search(current_center, this_indices, this_distances, window_radius, jann::search_params(128, 0, true));
00123             //found some neighbours
00124             JFR_DEBUG("neighbours found: " << nb_neighbours);
00125               if(nb_neighbours > 0) {
00126                 this_indices.resize(nb_neighbours, true);
00127                 std::vector<vector_type> this_neighbours;
00128                 this_neighbours.reserve(nb_neighbours);
00129                 //fill neighbours
00130                 for(ublas::vector<int>::const_iterator index = this_indices.begin();
00131                     index != this_indices.end();
00132                     ++index) {
00133                   this_neighbours.push_back(row(data, *index));
00134                 }
00135                 //compute new center
00136                 switch(kernel) {
00137                 case NORMAL : 
00138                   normal_center(this_neighbours, current_center, new_center);
00139                   break;
00140                 case UNIFORM :
00141                   uniform_center(this_neighbours, new_center);
00142                   break;
00143                 default :
00144                   JFR_RUN_TIME("don't know about this kernel")
00145                     break;
00146                 }
00147                 this_difference = new_center - current_center;
00148                 current_center = new_center;
00149                 iter++;
00150               }
00151           }while((iter < max_iterations) && (ublas::norm_2(this_difference) > convergence_threshold));
00152           //if the clusters are empty add the found center as a cluster center
00153           if(clusters.size() == 0) {
00154             JFR_DEBUG("added first cluster");
00155               clusters[0] = cluster(current_center);
00156             //        clusters[0] = cluster(current_center, &this_indices[0], this_indices.size());
00157             continue;
00158           }  else {
00159             //check validity of found cluster center
00160             bool found = false;
00161             typename std::map<size_t, cluster>::iterator cit;
00162             for(cit = clusters.begin(); cit != clusters.end(); ++cit) {
00163               //              JFR_DEBUG("dist_2 " << distance_2<NUMTYPE>(current_center, cit->second.center))
00164               if(distance_2<NUMTYPE>(current_center, cit->second.center) < min_distance_between_clusters){
00165                 found = true;
00166                 break;
00167               }
00168             }
00169             if(!found)
00170               //        clusters[clusters.size()] = cluster(current_center, &this_indices[0], this_indices.size());
00171               clusters[clusters.size()] = cluster(current_center);
00172             // else 
00173             //  cit->second.add(&this_indices[0], this_indices.size());
00174           }
00175         }//end of find clusters
00176         //prune the data
00177         ublas::matrix<NUMTYPE> clusters_centers(clusters.size(),D);
00178         for(typename std::map<size_t, cluster>::const_iterator cit = clusters.begin();
00179             cit != clusters.end();
00180             ++cit) {
00181           row(clusters_centers, cit->first) = cit->second.center;
00182         }
00183         jann::KD_tree_index< flann::L2<NUMTYPE> > clusters_tree(clusters_centers, 1);
00184         clusters_tree.build();
00185         for(size_t i = 0; i < data.size1(); i++) {
00186           ublas::vector<int> index; index.resize(1);
00187           ublas::vector<NUMTYPE> dist; dist.resize(1);
00188           clusters_tree.knn_search(row(data,i), index, dist, 1, jann::search_params(-1));
00189           if(index[0] != -1) {
00190             //            JFR_DEBUG("Assigned pt "<<i<<" to cluster "<<index[0])
00191               clusters[index[0]].add(i);
00192           }
00193         }
00194         return clusters.size();
00195       }
00196 
00198       std::map<size_t, cluster> found_clusters() const {
00199         return clusters;
00200       }
00201     private:
00203       std::map<size_t, cluster> clusters;
00205       NUMTYPE window_radius;
00207       NUMTYPE convergence_threshold;
00209       NUMTYPE min_distance_between_clusters;
00211       unsigned int max_iterations;
00213       KERNEL_TYPE kernel;
00215       NUMTYPE gaussian_variance;
00216 
00221       template <typename T>
00222       T sum_2(const ublas::vector<T>& v) {
00223         T sum = 0;
00224         for(size_t i = 0; i < v.size(); i++)
00225           sum+= v[i] * v[i];
00226         return sum;
00227       }
00228       
00233       template <typename T>
00234       T distance_2(const ublas::vector<T>& v1, const ublas::vector<T>& v2) {
00235         JFR_ASSERT(v1.size() == v2.size(), 
00236                    "mean_shift_clustering::distance_2: v1 and v2 sizes differ");
00237           ublas::vector<T> dif = v1 -v2;
00238         return sum_2(dif);
00239       }
00240     };
00241   }
00242 }
00243 #endif // HAVE_FLANN
00244 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

Generated on Wed Oct 15 2014 00:37:21 for Jafar by doxygen 1.7.6.1
LAAS-CNRS