From d33a00721ab563c8832f45f29640b254df7c6dd4 Mon Sep 17 00:00:00 2001 From: Hartmut Seichter Date: Wed, 4 Apr 2018 11:33:19 +0200 Subject: [PATCH] validated matrix multiplication --- src/core/include/pw/core/matrix.hpp | 14 +++++++------- src/core/tests/pwcore_test_matrix.cpp | 14 ++++++++++++++ 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/src/core/include/pw/core/matrix.hpp b/src/core/include/pw/core/matrix.hpp index 5e5a27f..cf21970 100644 --- a/src/core/include/pw/core/matrix.hpp +++ b/src/core/include/pw/core/matrix.hpp @@ -94,14 +94,13 @@ public: matrix& operator *= (const matrix& rhs); - matrix& copyFrom(const T* src) { for (unsigned int i = 0; i < R*C; ++i) { (*this).at(i) = src[i]; } return *this; } + matrix& copy_from_data(const T* src) { for (unsigned int i = 0; i < R*C; ++i) { (*this).at(i) = src[i]; } return *this; } matrix operator * (const matrix& rhs) const { return mul(*this,rhs); } - matrix - reshape() const { + const matrix reshape() const { matrix m; for (unsigned int r = 0; r < R; ++r) for (unsigned int c = 0; c < C; ++c) @@ -109,12 +108,12 @@ public: return m; } - matrix getColumn(unsigned int col) const { + const matrix get_column(unsigned int col) const { matrix c; for (unsigned int r = 0; r < R; ++r) c(r,0) = (this)(r,col); return c; } - matrix<1,C,T> getRow(unsigned int row) const { + const matrix<1,C,T> get_row(unsigned int row) const { matrix<1,C,T> r; for (unsigned int c = 0; c < C; ++c) r(0,c) = (this)(row,c); return r; } @@ -180,14 +179,15 @@ mul(const matrix& A, const matrix& B) { // aC == bR // set all null - matrix res = matrix::All(0); + matrix res; + res.fill(0); // compute all resulting cells for (unsigned int r = 0; r < aR; ++r) { for (unsigned int c = 0; c < bC; ++c) { // building inner product for (unsigned int iI = 0; iI < aCbR;iI++) { - res(r,c) += A(r,iI) * B(iI,c); + res.at(r,c) += A.at(r,iI) * B.at(iI,c); } } } diff --git a/src/core/tests/pwcore_test_matrix.cpp b/src/core/tests/pwcore_test_matrix.cpp index b9e3501..a8cdd63 100644 --- a/src/core/tests/pwcore_test_matrix.cpp +++ b/src/core/tests/pwcore_test_matrix.cpp @@ -33,5 +33,19 @@ int main(int argc,char **argv) { std::cout << "mscale = " << pw::serialize::matrix(mscale) << std::endl; + pw::matrix44d a; + + for (int r = 0; r < m.rows(); r++) { + for (int c = 0; c < m.cols(); c++) { + a.at(r,c) = r * m.cols() + c; + } + } + std::cout << "a = " << pw::serialize::matrix(a) << std::endl; + + pw::matrix44d r = a * mscale; + + std::cout << "a * mscale = " << pw::serialize::matrix(r) << std::endl; + + return 0; }