// paradiso/src/lib/matrixbase.hpp

#ifndef PARADISO_MATRIXBASE_HPP
#define PARADISO_MATRIXBASE_HPP
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <type_traits>
namespace paradiso {
/// CRTP base providing raw storage access, iteration, and element-wise
/// scalar arithmetic for fixed-size matrix/vector types.
///
/// Requirements on Derived (as used by this base):
///  - derives publicly from MatrixBase<Scalar, Derived>;
///  - has an accessible member named `data` that is a built-in array of
///    Scalar. size() is computed with std::extent, which yields 0 for
///    anything that is not a built-in array (e.g. std::array);
///  - is default-constructible (zero()) and copy-constructible (the binary
///    scalar operators).
/// lerp() additionally requires Derived to provide Derived + Derived and
/// Derived - Derived; this base does not define those.
template <typename Scalar, typename Derived> struct MatrixBase {
  using value_type = Scalar;
  using reference = Scalar&;
  using const_reference = const Scalar&;
  using pointer = Scalar*;
  using const_pointer = const Scalar*;
  using iterator = Scalar*;
  using const_iterator = const Scalar*;

  /// Pointer to the first element of Derived's `data` array.
  constexpr pointer data() noexcept { return &derived().data[0]; }
  constexpr const_pointer data() const noexcept { return &derived().data[0]; }

  constexpr iterator begin() noexcept { return data(); }
  constexpr iterator end() noexcept { return data() + size(); }
  constexpr const_iterator begin() const noexcept { return data(); }
  constexpr const_iterator end() const noexcept { return data() + size(); }
  constexpr const_iterator cbegin() const noexcept { return data(); }
  constexpr const_iterator cend() const noexcept { return data() + size(); }

  /// Unchecked element access by flat index (no bounds checking).
  constexpr reference operator[](std::size_t i) { return data()[i]; }
  constexpr const_reference operator[](std::size_t i) const {
    return data()[i];
  }

  /// CRTP downcast to the concrete type.
  constexpr Derived& derived() noexcept { return static_cast<Derived&>(*this); }
  constexpr const Derived& derived() const noexcept {
    return static_cast<const Derived&>(*this);
  }

  /// Number of elements: the extent of Derived's built-in `data` array.
  /// Static and constexpr so the constexpr iterators above can use it and it
  /// is usable in constant expressions; `m.size()` still works.
  static constexpr std::size_t size() noexcept {
    return std::extent<decltype(Derived::data)>::value;
  }

  /// Sets every element to v and returns the derived object for chaining.
  constexpr Derived& fill(const Scalar& v) noexcept {
    std::fill(begin(), end(), v);
    return derived();
  }

  /// Returns a default-constructed Derived with every element set to 0.
  static constexpr Derived zero() noexcept {
    Derived d;
    d.fill(Scalar(0));
    return d;
  }

  /// Sum of squared elements (dot product with itself).
  constexpr Scalar squared_norm() const {
    // dot() takes `const Derived&`; `*this` is only a MatrixBase, so the
    // CRTP downcast is required here.
    return dot(derived(), derived());
  }

  /// Euclidean (Frobenius) norm.
  constexpr Scalar norm() const { return std::sqrt(squared_norm()); }

  /// Copy scaled to unit norm. There is no zero-norm guard: for a zero
  /// vector this divides by zero.
  constexpr Derived normalized() const { return *this / norm(); }

  /// Scales this object in place to unit norm. There is no zero-norm guard:
  /// for a zero vector this divides by zero.
  constexpr void normalize() { *this /= norm(); }

  /// Sum of element-wise products of a and b (same Derived type, hence the
  /// same size()).
  static constexpr Scalar dot(const Derived& a, const Derived& b) {
    return std::inner_product(a.begin(), a.end(), b.begin(), Scalar{0});
  }

  /// Linear interpolation a + (b - a) * t. Requires Derived to define
  /// Derived + Derived and Derived - Derived.
  static constexpr Derived lerp(const Derived& a, const Derived& b,
                                const Scalar& t) {
    return a + (b - a) * t;
  }

  // Element-wise compound assignment with a scalar. These return the derived
  // object (canonical operator form) so they can be chained.
  constexpr Derived& operator*=(const Scalar& b) {
    for (auto& e : *this)
      e *= b;
    return derived();
  }
  constexpr Derived& operator/=(const Scalar& b) {
    for (auto& e : *this)
      e /= b;
    return derived();
  }
  constexpr Derived& operator+=(const Scalar& b) {
    for (auto& e : *this)
      e += b;
    return derived();
  }
  constexpr Derived& operator-=(const Scalar& b) {
    for (auto& e : *this)
      e -= b;
    return derived();
  }

  // Element-wise binary operators with a scalar, returning a new Derived.
  // Returned by non-const value so the result can be moved from. The base
  // compound operator is named explicitly because Derived may declare its
  // own (e.g. matrix-matrix) compound operators, which would hide these.
  constexpr Derived operator*(const Scalar& b) const {
    Derived r(derived());
    static_cast<MatrixBase&>(r) *= b;
    return r;
  }
  constexpr Derived operator/(const Scalar& b) const {
    Derived r(derived());
    static_cast<MatrixBase&>(r) /= b;
    return r;
  }
  constexpr Derived operator+(const Scalar& b) const {
    Derived r(derived());
    static_cast<MatrixBase&>(r) += b;
    return r;
  }
  constexpr Derived operator-(const Scalar& b) const {
    Derived r(derived());
    static_cast<MatrixBase&>(r) -= b;
    return r;
  }
};
} // namespace paradiso
#endif