validated matrix multiplication
This commit is contained in:
parent
75fb4c27f7
commit
d33a00721a
2 changed files with 21 additions and 7 deletions
|
@ -94,14 +94,13 @@ public:
|
||||||
|
|
||||||
matrix<R,C,T>& operator *= (const matrix<R,C,T>& rhs);
|
matrix<R,C,T>& operator *= (const matrix<R,C,T>& rhs);
|
||||||
|
|
||||||
matrix<R,C,T>& copyFrom(const T* src) { for (unsigned int i = 0; i < R*C; ++i) { (*this).at(i) = src[i]; } return *this; }
|
matrix<R,C,T>& copy_from_data(const T* src) { for (unsigned int i = 0; i < R*C; ++i) { (*this).at(i) = src[i]; } return *this; }
|
||||||
|
|
||||||
matrix<R,C,T> operator * (const matrix<R,C,T>& rhs) const {
|
matrix<R,C,T> operator * (const matrix<R,C,T>& rhs) const {
|
||||||
return mul(*this,rhs);
|
return mul(*this,rhs);
|
||||||
}
|
}
|
||||||
|
|
||||||
matrix<C,R,T>
|
const matrix<C,R,T> reshape() const {
|
||||||
reshape() const {
|
|
||||||
matrix<C,R,T> m;
|
matrix<C,R,T> m;
|
||||||
for (unsigned int r = 0; r < R; ++r)
|
for (unsigned int r = 0; r < R; ++r)
|
||||||
for (unsigned int c = 0; c < C; ++c)
|
for (unsigned int c = 0; c < C; ++c)
|
||||||
|
@ -109,12 +108,12 @@ public:
|
||||||
return m;
|
return m;
|
||||||
}
|
}
|
||||||
|
|
||||||
matrix<R,1,T> getColumn(unsigned int col) const {
|
const matrix<R,1,T> get_column(unsigned int col) const {
|
||||||
matrix<R,1,T> c; for (unsigned int r = 0; r < R; ++r) c(r,0) = (this)(r,col);
|
matrix<R,1,T> c; for (unsigned int r = 0; r < R; ++r) c(r,0) = (this)(r,col);
|
||||||
return c;
|
return c;
|
||||||
}
|
}
|
||||||
|
|
||||||
matrix<1,C,T> getRow(unsigned int row) const {
|
const matrix<1,C,T> get_row(unsigned int row) const {
|
||||||
matrix<1,C,T> r; for (unsigned int c = 0; c < C; ++c) r(0,c) = (this)(row,c);
|
matrix<1,C,T> r; for (unsigned int c = 0; c < C; ++c) r(0,c) = (this)(row,c);
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
@ -180,14 +179,15 @@ mul(const matrix<aR,aCbR,T>& A, const matrix<aCbR,bC,T>& B)
|
||||||
{
|
{
|
||||||
// aC == bR
|
// aC == bR
|
||||||
// set all null
|
// set all null
|
||||||
matrix<aR,bC,T> res = matrix<aR,bC,T>::All(0);
|
matrix<aR,bC,T> res;
|
||||||
|
res.fill(0);
|
||||||
|
|
||||||
// compute all resulting cells
|
// compute all resulting cells
|
||||||
for (unsigned int r = 0; r < aR; ++r) {
|
for (unsigned int r = 0; r < aR; ++r) {
|
||||||
for (unsigned int c = 0; c < bC; ++c) {
|
for (unsigned int c = 0; c < bC; ++c) {
|
||||||
// building inner product
|
// building inner product
|
||||||
for (unsigned int iI = 0; iI < aCbR;iI++) {
|
for (unsigned int iI = 0; iI < aCbR;iI++) {
|
||||||
res(r,c) += A(r,iI) * B(iI,c);
|
res.at(r,c) += A.at(r,iI) * B.at(iI,c);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,5 +33,19 @@ int main(int argc,char **argv) {
|
||||||
std::cout << "mscale = " << pw::serialize::matrix(mscale) << std::endl;
|
std::cout << "mscale = " << pw::serialize::matrix(mscale) << std::endl;
|
||||||
|
|
||||||
|
|
||||||
|
pw::matrix44d a;
|
||||||
|
|
||||||
|
for (int r = 0; r < m.rows(); r++) {
|
||||||
|
for (int c = 0; c < m.cols(); c++) {
|
||||||
|
a.at(r,c) = r * m.cols() + c;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::cout << "a = " << pw::serialize::matrix(a) << std::endl;
|
||||||
|
|
||||||
|
pw::matrix44d r = a * mscale;
|
||||||
|
|
||||||
|
std::cout << "a * mscale = " << pw::serialize::matrix(r) << std::endl;
|
||||||
|
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue