DBSCAN 模板类的 C++ 实现

DBSCAN 模板类的 C++ 实现,第1张

DBSCAN 模板类的 C++ 实现

好久没更新了,写个聚类凑凑数=.=

  • 写了一个 DBSCAN 的 C++ 模板类,可以用来处理简单的聚类问题,运行耗时也很短
  • 按照 https://zh.wikipedia.org/wiki/DBSCAN 中的算法实现,原理很简单,这里就不重复了
  • 使用距离矩阵预先存储所有点对之间的距离,避免重复计算
  • 可选:用 OpenMP 并行计算距离矩阵,但实测运行耗时并没有明显减少,可能是测试用例太简单
  • 主要的耗时应该在距离矩阵的计算上:两重 for 循环使计算量随点数呈平方增长(O(n²)),距离矩阵的内存占用也是 O(n²)。如果数据量比较大,可以用 kd-tree 等空间索引结构来优化:kd-tree 查询近邻点更高效,聚类时判断邻域点是否满足要求也会更快
  • 可以找规模更大一些的数据集来测试聚类效果
  • 下面直接上代码,后面不再逐段介绍,代码里有注释,应该不难看懂
  • 项目代码 https://gitee.com/leox24/dbscan_template
  • 参考代码 https://github.com/buresu/DBSCAN
dbscan.h
#pragma once
#include 
#include 
#include 
#include 
#include 
#include 

using std::deque;
using std::vector;

/// Minimal 2-D point with float coordinates. It can be used as the template
/// argument of DBScan, which provides a GetDistance overload for it.
struct Point2d
{
    float x{0.0f};  // x coordinate, zero by default
    float y{0.0f};  // y coordinate, zero by default
};

// namespace TBD
// {


/// DBSCAN (density-based spatial clustering of applications with noise) over
/// points of type T.
///
/// Usage: construct with eps/min_num, then call GetClusters(points). Per-point
/// labels (-1 noise, >= 1 cluster id) are available afterwards via GetLabels().
///
/// All pairwise distances are cached in a dense size x size matrix, so memory
/// grows quadratically with the number of points. To support a new point type,
/// add a GetDistance overload for it below.
template <typename T> class DBScan
{
  private:
    // minimum neighbourhood size for a point to count as a core point
    int _min_num;

    // maximum distance at which two points count as neighbours
    double _eps;

    // cluster label of each point: -1 noise, 0 not yet visited, >= 1 cluster id
    std::vector<int> _labels;

    // _distance_mat(i, j) holds the distance between points i and j
    Eigen::MatrixXd _distance_mat;

    // TODO: define a GetDistance overload for your own point struct/class.
    double GetDistance(const Eigen::Vector2d& p1, const Eigen::Vector2d& p2) const
    {
        return (p1 - p2).norm();
    }
    double GetDistance(const Point2d& p1, const Point2d& p2) const
    {
        // Differences are taken in float (as before) and then squared in double.
        const double dx = p1.x - p2.x;
        const double dy = p1.y - p2.y;
        return std::sqrt(dx * dx + dy * dy);
    }

    // Computes the distance matrix and labels every point; returns the number
    // of clusters found.
    int Run(const std::vector<T> &all_pts);

    // Starts a new cluster at pt_idx if it is a core point (then increments
    // cluster_idx and returns true); otherwise marks it as noise and returns false.
    bool ExpandCluster(int pt_idx, int &cluster_idx);

  public:
    // eps is taken as double (float arguments still convert implicitly) so a
    // double threshold such as 0.1 is not rounded to float precision before it
    // is stored. The initialiser order matches the declaration order.
    DBScan(double eps, int min_num) : _min_num(min_num), _eps(eps) {}

    // Clusters all_pts. Element 0 of the result holds the noise points and
    // element k (k >= 1) holds the points of cluster k.
    std::vector<std::vector<T>> GetClusters(std::vector<T> all_pts);

    // Per-point labels from the most recent GetClusters() call.
    std::vector<int> GetLabels() const { return _labels; }
};

template  int DBScan::Run(const vector &all_pts)
{
    int size = all_pts.size();
    int cluster_idx = 1;

    // -1:noise, 0:unlabel, >1:cluster index
    _labels.resize(size, 0);

    // 1. calculate the distance between each two points, opt(using OpenMP if data is big)
    _distance_mat = Eigen::MatrixXd::Zero(size, size);
// #pragma omp parallel for schedule(runtime)
    for (int i = 0; i < size; ++i)
    {
        for (int j = i; j < size; ++j)
        {
            if (i != j)
            {
                _distance_mat(i, j) = GetDistance(all_pts.data()[i], all_pts.data()[j]);
                _distance_mat(j, i) = _distance_mat(i, j);
            }
        }
    }
    // std::cout << _distance_mat << std::endl;

    // 2. do clustering
    for (size_t i = 0; i < size; i++)
    {
        if (_labels[i] != 0)
            continue;
        ExpandCluster(i, cluster_idx);
    }

    return cluster_idx - 1;
}

template  bool DBScan::ExpandCluster(int pt_idx, int &cluster_idx)
{
    // 1.region query
    deque seeds_idx;
    for (size_t col = 0; col < _distance_mat.cols(); col++)
    {
        if (_distance_mat(pt_idx, col) < _eps)
        {
            seeds_idx.emplace_back(col);
        }
    }

    // 2.check point numbers of neighbors, whether its noise
    if (seeds_idx.size() < _min_num)
    {
        _labels[pt_idx] = -1;
        return false;
    }
    // 3.label point
    for (size_t i = 0; i < seeds_idx.size(); i++)
    {
        _labels[seeds_idx[i]] = cluster_idx;
    }

    // 4.do neighbors clustering
    seeds_idx.pop_front();
    while (!seeds_idx.empty())
    {
        auto &row = seeds_idx.front();
        // region query
        vector temp_idx;
        temp_idx.reserve(_distance_mat.cols());
        for (size_t col = 0; col < _distance_mat.cols(); col++)
        {
            if (_distance_mat(row, col) <= _eps)
                temp_idx.emplace_back(col);
        }
        // check point numbers of neighbors, whether its noise
        if (temp_idx.size() > _min_num)
        {
            // label point
            for (size_t i = 0; i < temp_idx.size(); i++)
            {
                // if already label, abort
                if (_labels[temp_idx[i]] >= 1)
                    continue;
                if (_labels[temp_idx[i]] == 0)
                    seeds_idx.emplace_back(temp_idx[i]);
                _labels[temp_idx[i]] = cluster_idx;
            }
        }

        seeds_idx.pop_front();
    }

    cluster_idx++;
    return true;
}
template  vector> DBScan::GetClusters(std::vector all_pts)
{
    int clusters_num = this->Run(all_pts);
    vector> result(clusters_num + 1);

    for (size_t i = 0; i < _labels.size(); i++)
    {
        auto &label = _labels[i];
        if (label < 1)
        {
            // noise point, index=0
            result[label + 1].emplace_back(all_pts[i]);
            continue;
        }

        result[label].emplace_back(all_pts[i]);
    }

    return result;
}
// } // namespace TBD
main.cc
#include 
#include 
#include 
#include "dbscan.h"
using namespace std;

// Clusters ten hand-placed 2-D points with eps = 20 and min_num = 3. Expected
// result: bucket 0 holds the single noise point (points[6]); buckets 1..3 hold
// the three groups of three points.
//
// NOTE(review): before C++17, Eigen requires Eigen::aligned_allocator for a
// std::vector of fixed-size vectorizable types such as Eigen::Vector2d. This
// project's CMakeLists requests C++11, so confirm alignment is not an issue on
// the target platform, or build as C++17.
TEST(testcase, test_dbscan)
{
    std::vector<Eigen::Vector2d> points(10);
    points[0].x() = 20; points[0].y() = 21;
    points[1].x() = 20; points[1].y() = 25;
    points[2].x() = 28; points[2].y() = 22;
    points[3].x() = 30; points[3].y() = 52;
    points[4].x() = 26; points[4].y() = 70;
    points[5].x() = 30; points[5].y() = 75;
    points[6].x() = 0;  points[6].y() = 70;
    points[7].x() = 70; points[7].y() = 50;
    points[8].x() = 67; points[8].y() = 69;
    points[9].x() = 80; points[9].y() = 35;

    std::vector<std::vector<Eigen::Vector2d>> expected(4);
    expected[0].push_back(points[6]);
    expected[1].push_back(points[0]);
    expected[1].push_back(points[1]);
    expected[1].push_back(points[2]);
    expected[2].push_back(points[3]);
    expected[2].push_back(points[4]);
    expected[2].push_back(points[5]);
    expected[3].push_back(points[7]);
    expected[3].push_back(points[8]);
    expected[3].push_back(points[9]);

    DBScan<Eigen::Vector2d> dbscan(20.0, 3);
    const std::vector<std::vector<Eigen::Vector2d>> res = dbscan.GetClusters(points);

    // res[0] (the noise bucket) is indexed below, so fail cleanly if it is absent.
    ASSERT_FALSE(res.empty());

    // '\n' instead of std::endl: no need to flush after every line.
    std::cout << "noise " << '\n';
    for (const auto &r : res[0])
        std::cout << r.transpose() << '\n';
    for (size_t i = 1; i < res.size(); ++i)
    {
        std::cout << "cluster " << i << '\n';
        for (const auto &p : res[i])
            std::cout << p.transpose() << '\n';
    }

    EXPECT_EQ(expected, res);
}

// Test-runner entry point: initialises GoogleTest with the command-line
// arguments and returns the result of RUN_ALL_TESTS(). The old trailing
// `return 0;` was unreachable and has been removed.
int main(int argc, char** argv)
{
    testing::InitGoogleTest(&argc, argv);
    return RUN_ALL_TESTS();
}
CMakeLists.txt
# The OpenMP::OpenMP_CXX imported target and the GoogleTest module used below
# both require CMake >= 3.9.
cmake_minimum_required(VERSION 3.9)
project(dbscan)

# Default to an optimised build, but let -DCMAKE_BUILD_TYPE=... override it.
if(NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE Release)
endif()

# Request C++11 portably instead of hard-coding -std flags twice. Assigning
# CMAKE_CXX_FLAGS directly would discard flags supplied by the user or toolchain.
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

find_package(OpenMP REQUIRED)
# NOTE(review): dbscan.h uses Eigen, but nothing here locates it. If Eigen is not
# on the default include path, add find_package(Eigen3 3.3 REQUIRED NO_MODULE)
# and link Eigen3::Eigen below.

include_directories("./src")
include(GoogleTest)

file(GLOB_RECURSE PROJECT_HEADERS "src/*.h" "src/*.hpp")
file(GLOB_RECURSE PROJECT_SOURCES "src/*.cc" "src/*.cpp")

add_executable(${PROJECT_NAME} ${PROJECT_HEADERS} ${PROJECT_SOURCES})
target_link_libraries(${PROJECT_NAME} OpenMP::OpenMP_CXX gtest pthread)


欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/4751636.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-11-08
下一篇 2022-11-08

发表评论

登录后才能评论

评论列表(0条)

保存