ThunderSVM
ThunderSVM: An Open-Source SVM Library on GPUs and CPUs
kernelmatrix.h
1 //
2 // Created by jiashuai on 17-9-19.
3 //
4 
5 #ifndef THUNDERSVM_KERNELMATRIX_H
6 #define THUNDERSVM_KERNELMATRIX_H
7 
8 #include "thundersvm.h"
9 #include "syncarray.h"
10 #include "dataset.h"
11 #include "svmparam.h"
12 
17 public:
// Construct a KernelMatrix from a set of instances and the SVM parameters.
// `instances` is taken by const reference; `param` (SvmParam) is taken by value.
// NOTE(review): the definition lives in kernelmatrix.cpp and is not visible here;
// presumably the instances are converted into the sparse members of this class — confirm.
23  explicit KernelMatrix(const DataSet::node2d &instances, SvmParam param);
24 
// Fill `kernel_rows` with kernel-matrix rows selected through `idx`.
// `idx` is read-only; `kernel_rows` is a non-const reference and receives the result.
// NOTE(review): presumably `idx` holds indices of instances stored in this
// KernelMatrix — confirm against kernelmatrix.cpp.
30  void get_rows(const SyncArray<int> &idx, SyncArray<float_type> &kernel_rows) const;
31 
// Overload of get_rows that takes the instances themselves instead of indices.
// `instances` is read-only; `kernel_rows` is a non-const reference and receives the result.
// NOTE(review): presumably each output row holds the kernel values between one of
// `instances` and the instances stored in this KernelMatrix — confirm.
37  void get_rows(const DataSet::node2d &instances, SyncArray<float_type> &kernel_rows) const;
38 
// Return the diagonal elements of the kernel matrix (by const reference).
40  const SyncArray<float_type> &diag() const;
41 
// Return the number of instances in this KernelMatrix (value of n_instances_).
43  size_t n_instances() const { return n_instances_; };
44 
// Return the maximum number of features of the instances (value of n_features_).
46  size_t n_features() const { return n_features_; }
47 
// Return the number of non-zero features of all instances (value of nnz_).
49  size_t nnz() const {return nnz_;};//total count of non-zero feature entries
50 private:
// Copy assignment and copy construction are declared in the private section,
// so code outside the class cannot copy a KernelMatrix.
// NOTE(review): no definitions are visible here; presumably these are intentionally
// left undefined (pre-C++11 non-copyable idiom) — confirm in kernelmatrix.cpp.
// NOTE(review): operator= is const-qualified, which is unusual for an assignment
// operator — confirm whether the trailing `const` is intended.
51  KernelMatrix &operator=(const KernelMatrix &) const;
52 
53  KernelMatrix(const KernelMatrix &);
54 
// Internal state of the kernel matrix.
// NOTE(review): col_ind_ / row_ptr_ look like the column-index and row-pointer
// arrays of a CSR sparse matrix (cf. dns_csr_mul) — confirm against kernelmatrix.cpp.
56  SyncArray<int> col_ind_;
57  SyncArray<int> row_ptr_;
// NOTE(review): presumably the dot product of each instance with itself — confirm.
59  SyncArray<float_type> self_dot_;
// Value returned by nnz().
60  size_t nnz_;
// Value returned by n_instances().
61  size_t n_instances_;
// Value returned by n_features().
62  size_t n_features_;
// SVM parameters held by value.
63  SvmParam param;
64 
// Private helper: writes into `result` (non-const reference) from `dense_mat`
// (read-only) and `n_rows`.
// NOTE(review): the name suggests a dense-matrix by CSR-matrix multiplication where
// `n_rows` is the row count of `dense_mat` — confirm against kernelmatrix.cpp.
65  void dns_csr_mul(const SyncArray<float_type> &dense_mat, int n_rows, SyncArray<float_type> &result) const;
66 
// Private helper: writes into `dot_product` (non-const reference) for the
// selection given by `idx` (read-only).
// NOTE(review): presumably computes dot products between the instances selected by
// `idx` and all stored instances — confirm against kernelmatrix.cpp.
67  void get_dot_product(const SyncArray<int> &idx, SyncArray<float_type> &dot_product) const;
68 
// Private helper: overload of get_dot_product that takes the instances themselves
// (read-only) instead of indices; writes into `dot_product` (non-const reference).
// NOTE(review): presumably computes dot products between `instances` and all
// stored instances — confirm against kernelmatrix.cpp.
69  void get_dot_product(const DataSet::node2d &instances, SyncArray<float_type> &dot_product) const;
70 };
71 #endif //THUNDERSVM_KERNELMATRIX_H
void get_rows(const SyncArray< int > &idx, SyncArray< float_type > &kernel_rows) const
Definition: kernelmatrix.cpp:71
The management class of kernel values.
Definition: kernelmatrix.h:16
size_t n_features() const
the maximum number of features of instances
Definition: kernelmatrix.h:46
const SyncArray< float_type > & diag() const
return the diagonal elements of kernel matrix
Definition: kernelmatrix.cpp:123
size_t nnz() const
the number of non-zero features of all instances
Definition: kernelmatrix.h:49
size_t n_instances() const
the number of instances in KernelMatrix
Definition: kernelmatrix.h:43
params for ThunderSVM
Definition: svmparam.h:13
KernelMatrix(const DataSet::node2d &instances, SvmParam param)
Definition: kernelmatrix.cpp:9