00001 #ifndef _MEAN_SHIFT_CLUSTERING_
00002 #define _MEAN_SHIFT_CLUSTERING_
00003
00004 #include "jafarConfig.h"
00005 #ifdef HAVE_FLANN
00006
#include <cmath>
#include <cstddef>
#include <map>
#include <set>
#include <vector>

#include "jmath/jann.hpp"
#include "jmath/jblas.hpp"
00010 namespace jafar {
00011 namespace jmath {
00021 template<typename NUMTYPE, size_t D>
00022 class mean_shift_clustering {
00024 public:
00025 typedef ublas::bounded_vector<NUMTYPE, D> vector_type;
00026 struct cluster {
00028 vector_type center;
00030 std::vector<int> inliers;
00031 cluster() : inliers(0) {}
00033 cluster(const vector_type& _center) : center(_center), inliers(0) {}
00034 cluster(const vector_type& _center, const std::vector<int> &_inliers) :
00035 center(_center), inliers(_inliers) {}
00036 cluster(const vector_type& _center, int* _inliers, size_t nb_inliers) :
00037 center(_center)
00038 {
00039 this->add(_inliers, nb_inliers);
00040 }
00042 void add(size_t pt) {
00043 inliers.push_back(pt);
00044 }
00045 template <class InputIterator>
00046 void add(InputIterator first, InputIterator last) {
00047 inliers.insert(inliers.end(), first, last);
00048 }
00049 void add(int* _inliers, size_t nb_inliers) {
00050 for(size_t i = 0; i < nb_inliers; i++, _inliers++)
00051 inliers.push_back(*_inliers);
00052 }
00053 };
00054 private:
00056 void normal_center(const std::vector<vector_type> &points,
00057 const vector_type& mean,
00058 vector_type& center)
00059 {
00060 JFR_DEBUG("nb of points "<<points.size());
00061 center = ublas::zero_vector<NUMTYPE>(D);
00062 for(typename std::vector<vector_type>::const_iterator pt = points.begin();
00063 pt != points.end();
00064 ++pt) {
00065 vector_type difference = *pt - mean;
00066 NUMTYPE n = exp(-sum_2<NUMTYPE>(difference)/(2.0 * gaussian_variance));
00067
00068
00069 difference*= n;
00070 center+= difference;
00071 }
00072 center/=NUMTYPE(points.size());
00073 center+=mean;
00074 }
00076 void uniform_center(const std::vector<vector_type> &points,
00077 vector_type& center)
00078 {
00079 center = ublas::zero_vector<NUMTYPE>(D);
00080 for(typename std::vector<vector_type>::const_iterator pt = points.begin();
00081 pt != points.end();
00082 ++pt) {
00083 center+= *pt;
00084 }
00085 center/=NUMTYPE(points.size());
00086 }
00087
00088 public:
00090 enum KERNEL_TYPE {NORMAL, UNIFORM};
00092 mean_shift_clustering(NUMTYPE _radius = 1.0,
00093 NUMTYPE _threshold = 0.1,
00094 NUMTYPE _min_distance = 0.1,
00095 unsigned int _max = 100,
00096 mean_shift_clustering::KERNEL_TYPE _kernel = mean_shift_clustering::NORMAL,
00097 NUMTYPE _variance = 1.0) :
00098 window_radius(_radius), convergence_threshold(_threshold),
00099 min_distance_between_clusters(_min_distance), max_iterations(_max),
00100 kernel(_kernel), gaussian_variance(_variance) {}
00104 size_t run(const ublas::matrix<NUMTYPE>& data)
00105 {
00106 using namespace jafar::jmath;
00107 using namespace std;
00108
00109 JFR_ASSERT(data.size1() > 1, "need at least two points");
00110 JFR_ASSERT(data.size2() == D, "points must be of dimension D");
00111 jann::KD_tree_index< flann::L2<NUMTYPE> > points_tree(data, 4);
00112 points_tree.build();
00113 vector_type current_center;
00114 vector_type new_center;
00115 for(size_t i = 0; i < data.size1(); i++) {
00116 current_center = row(data, i);
00117 unsigned int iter = 0;
00118 vector_type this_difference;
00119 do {
00120 ublas::vector<int> this_indices;
00121 ublas::vector<NUMTYPE> this_distances;
00122 int nb_neighbours = points_tree.radius_search(current_center, this_indices, this_distances, window_radius, jann::search_params(128, 0, true));
00123
00124 JFR_DEBUG("neighbours found: " << nb_neighbours);
00125 if(nb_neighbours > 0) {
00126 this_indices.resize(nb_neighbours, true);
00127 std::vector<vector_type> this_neighbours;
00128 this_neighbours.reserve(nb_neighbours);
00129
00130 for(ublas::vector<int>::const_iterator index = this_indices.begin();
00131 index != this_indices.end();
00132 ++index) {
00133 this_neighbours.push_back(row(data, *index));
00134 }
00135
00136 switch(kernel) {
00137 case NORMAL :
00138 normal_center(this_neighbours, current_center, new_center);
00139 break;
00140 case UNIFORM :
00141 uniform_center(this_neighbours, new_center);
00142 break;
00143 default :
00144 JFR_RUN_TIME("don't know about this kernel")
00145 break;
00146 }
00147 this_difference = new_center - current_center;
00148 current_center = new_center;
00149 iter++;
00150 }
00151 }while((iter < max_iterations) && (ublas::norm_2(this_difference) > convergence_threshold));
00152
00153 if(clusters.size() == 0) {
00154 JFR_DEBUG("added first cluster");
00155 clusters[0] = cluster(current_center);
00156
00157 continue;
00158 } else {
00159
00160 bool found = false;
00161 typename std::map<size_t, cluster>::iterator cit;
00162 for(cit = clusters.begin(); cit != clusters.end(); ++cit) {
00163
00164 if(distance_2<NUMTYPE>(current_center, cit->second.center) < min_distance_between_clusters){
00165 found = true;
00166 break;
00167 }
00168 }
00169 if(!found)
00170
00171 clusters[clusters.size()] = cluster(current_center);
00172
00173
00174 }
00175 }
00176
00177 ublas::matrix<NUMTYPE> clusters_centers(clusters.size(),D);
00178 for(typename std::map<size_t, cluster>::const_iterator cit = clusters.begin();
00179 cit != clusters.end();
00180 ++cit) {
00181 row(clusters_centers, cit->first) = cit->second.center;
00182 }
00183 jann::KD_tree_index< flann::L2<NUMTYPE> > clusters_tree(clusters_centers, 1);
00184 clusters_tree.build();
00185 for(size_t i = 0; i < data.size1(); i++) {
00186 ublas::vector<int> index; index.resize(1);
00187 ublas::vector<NUMTYPE> dist; dist.resize(1);
00188 clusters_tree.knn_search(row(data,i), index, dist, 1, jann::search_params(-1));
00189 if(index[0] != -1) {
00190
00191 clusters[index[0]].add(i);
00192 }
00193 }
00194 return clusters.size();
00195 }
00196
00198 std::map<size_t, cluster> found_clusters() const {
00199 return clusters;
00200 }
00201 private:
00203 std::map<size_t, cluster> clusters;
00205 NUMTYPE window_radius;
00207 NUMTYPE convergence_threshold;
00209 NUMTYPE min_distance_between_clusters;
00211 unsigned int max_iterations;
00213 KERNEL_TYPE kernel;
00215 NUMTYPE gaussian_variance;
00216
00221 template <typename T>
00222 T sum_2(const ublas::vector<T>& v) {
00223 T sum = 0;
00224 for(size_t i = 0; i < v.size(); i++)
00225 sum+= v[i] * v[i];
00226 return sum;
00227 }
00228
00233 template <typename T>
00234 T distance_2(const ublas::vector<T>& v1, const ublas::vector<T>& v2) {
00235 JFR_ASSERT(v1.size() == v2.size(),
00236 "mean_shift_clustering::distance_2: v1 and v2 sizes differ");
00237 ublas::vector<T> dif = v1 -v2;
00238 return sum_2(dif);
00239 }
00240 };
00241 }
00242 }
00243 #endif // HAVE_FLANN
00244 #endif