KrisLibrary  1.0.0
MatrixTemplate.h
1 #ifndef MATH_MATRIX_TEMPLATE_H
2 #define MATH_MATRIX_TEMPLATE_H
3 
4 #include "VectorTemplate.h"
5 #include <KrisLibrary/errors.h>
6 
7 namespace Math {
8 
16 template <class T>
18 {
19 public:
20  typedef MatrixIterator<T> MyT;
21  inline MatrixIterator() :ptr(NULL),rowPtr(NULL),istride(0),jstride(0) {}
22  inline MatrixIterator(const MyT& i) :ptr(i.ptr),rowPtr(i.rowPtr),istride(i.istride),jstride(i.jstride) {}
23  inline explicit MatrixIterator(T* _ptr,int _istride,int _jstride) :ptr(_ptr),rowPtr(_ptr),istride(_istride),jstride(_jstride) {}
24  inline T& operator*() { return *ptr; }
25  inline T* operator->() { return ptr; }
26  inline MyT& nextRow() { rowPtr+=istride; ptr=rowPtr; return *this; }
27  inline MyT& prevRow() { rowPtr-=istride; ptr=rowPtr; return *this; }
28  inline MyT& nextCol() { ptr+=jstride; return *this; }
29  inline MyT& prevCol() { ptr-=jstride; return *this; }
30  //inline MyT& operator +=(int i) { ptr+=i*stride; return *this; }
31  //inline MyT& operator -=(int i) { ptr-=i*stride; return *this; }
32  inline bool operator !=(const MyT& i) { return ptr!=i.ptr; }
33  inline bool operator ==(const MyT& i) { return ptr==i.ptr; }
34  inline bool operator < (const MyT& i) { return ptr<i.ptr; }
35  inline bool operator > (const MyT& i) { return ptr>i.ptr; }
36 
37  T *ptr, *rowPtr;
38  int istride,jstride;
39 };
40 
64 template <class T>
65 class MatrixTemplate
66 {
67 public:
68  typedef class MatrixTemplate<T> MyT;
69  typedef class MatrixIterator<T> ItT;
70  typedef class VectorTemplate<T> VectorT;
71 
73  MatrixTemplate(const MyT&);
74  MatrixTemplate(MyT&&);
75  MatrixTemplate(int m, int n);
76  MatrixTemplate(int m, int n, T initval);
77  MatrixTemplate(int m, int n, const T* vals);
78  MatrixTemplate(int m, int n, const VectorT* rows);
79  ~MatrixTemplate();
80 
81  inline T* getPointer() const { return vals; }
82  inline int getCapacity() const { return capacity; }
83  inline T* getStart() const { return vals+base; }
84  T* getRowPtr(int i) const;
85  T* getColPtr(int j) const;
86  ItT begin() const;
87  ItT end() const;
88  inline int numRows() const { return m; }
89  inline int numCols() const { return n; }
90 
91  void resize(int m, int n);
92  void resize(int m, int n, T initval);
93  void resizePersist(int m, int n);
94  void resizePersist(int m, int n, T initval);
95  void clear();
96 
97  MyT& operator = (const MyT&);
98  MyT& operator = (MyT&&);
99  bool operator == (const MyT&) const;
100  inline bool operator != (const MyT& a) const { return !operator==(a); }
101  inline const T& operator() (int,int) const;
102  inline T& operator() (int,int);
103  inline void operator += (const MyT& a) { inc(a); }
104  inline void operator -= (const MyT& a) { dec(a); }
105  inline void operator *= (T c) { inplaceMul(c); }
106  inline void operator /= (T c) { inplaceDiv(c); }
107  //NOTE: this is slow...
108  void operator *= (const MyT&);
109 
110  void copy(const MyT&);
111  template <class T2> void copy(const MatrixTemplate<T2>&);
112  void copy(const T* vals);
113  void copyColumns(const T* vals);
114  void copyRows(const VectorT* rows);
115  void copyCols(const VectorT* cols);
116  void copySubMatrix(int i, int j, const MyT&);
117  void swap(MyT&);
118  void swapCopy(MyT&);
119  void add(const MyT&, const MyT&);
120  void sub(const MyT&, const MyT&);
121  void mul(const MyT&, const MyT&);
122  void mulTransposeA(const MyT& a, const MyT& b);
123  void mulTransposeB(const MyT& a, const MyT& b);
124  void mul(const VectorT&, VectorT&) const;
125  void mulTranspose(const VectorT&, VectorT&) const;
126  void mul(const MyT&, T);
127  void div(const MyT&, T);
128  void inc(const MyT&);
129  void dec(const MyT&);
130  void madd(const MyT&, T);
131  void madd(const VectorT&, VectorT&) const;
132  void maddTranspose(const VectorT&, VectorT&) const;
133 
134  void setRef(const MyT&,int i=0,int j=0,int istride=1,int jstride=1,int m=-1,int n=-1);
135  void setRef(T* vals,int length,int offset=0,int istride=1,int jstride=1,int m=-1,int n=-1);
136  void setRefTranspose(const MyT&);
137  void set(T c);
138  void setZero();
139  void setIdentity();
140  void setNegative(const MyT&);
141  void setTranspose(const MyT&);
142  void setAdjoint(const MyT&);
143  void setInverse(const MyT&); //uses LU decomposition
144 
145  void inplaceNegative();
146  void inplaceMul(T);
147  void inplaceDiv(T);
148  void inplaceTranspose();
149  void inplaceAdjoint();
150  void inplaceInverse();
151 
152  void getSubMatrixCopy(int i, int j, MyT&) const;
153 
154  inline bool isRef() const { return !allocated; }
155  inline bool hasDims(int _m,int _n) const { return _m==m&&_n==n; }
156  inline bool isEmpty() const { return vals==NULL; }
157  inline bool isValidRow(int i) const { return i >= 0 && i < m; }
158  inline bool isValidCol(int j) const { return j >= 0 && j < n; }
159  inline bool isCompact() const { return (istride==n&&jstride==1); }
160  inline bool isRowMajor() const { return istride>jstride; }
161  inline bool isColMajor() const { return jstride>istride; }
162  inline bool isSquare() const { return m == n; }
163  bool isValid() const;
164  bool isZero(T eps=0) const;
165  bool isEqual(const MyT& a,T eps=0) const;
166  bool isIdentity() const;
167  bool isDiagonal() const;
168  bool isSymmetric() const;
169  //bool isLowerTriangular() const;
170  //bool isUpperTriangular() const;
171  bool isDiagonallyDominant() const;
172  bool isOrthogonal() const;
173  bool isInvertible() const;
174 
175  T trace() const;
176  T determinant() const;
177  T diagonalProduct() const;
178  T minElement(int*i=NULL,int*j=NULL) const;
179  T maxElement(int*i=NULL,int*j=NULL) const;
180  T minAbsElement(int*i=NULL,int*j=NULL) const;
181  T maxAbsElement(int*i=NULL,int*j=NULL) const;
182 
183  bool Read(File&);
184  bool Write(File&) const;
185 
186  void getRowRef(int i, VectorT&) const;
187  void getColRef(int j, VectorT&) const;
188  void getDiagRef(int d, VectorT&) const;
189  inline VectorT row(int i) const { VectorT a; getRowRef(i,a); return a; }
190  inline VectorT col(int j) const { VectorT a; getColRef(j,a); return a; }
191  inline VectorT diag(int d) const { VectorT a; getDiagRef(d,a); return a; }
192  inline void getRowCopy(int i, VectorT& b) const { VectorT a; getRowRef(i,a); b.copy(a); }
193  inline void getColCopy(int j, VectorT& b) const { VectorT a; getColRef(j,a); b.copy(a); }
194  inline void getDiagCopy(int d, VectorT& b) const { VectorT a; getDiagRef(d,a); b.copy(a); }
195  inline void setRow(int i, T c) { VectorT a; getRowRef(i,a); a.set(c); }
196  inline void setCol(int j, T c) { VectorT a; getColRef(j,a); a.set(c); }
197  inline void setDiag(int d, T c) { VectorT a; getDiagRef(d,a); a.set(c); }
198  inline void copyRow(int i, const VectorT& b) { VectorT a; getRowRef(i,a); a.copy(b); }
199  inline void copyCol(int j, const VectorT& b) { VectorT a; getColRef(j,a); a.copy(b); }
200  inline void copyDiag(int d, const VectorT& b) { VectorT a; getDiagRef(d,a); a.copy(b); }
201  inline void copyRow(int i, const T* b) { VectorT a; getRowRef(i,a); a.copy(b); }
202  inline void copyCol(int j, const T* b) { VectorT a; getColRef(j,a); a.copy(b); }
203  inline void copyDiag(int d, const T* b) { VectorT a; getDiagRef(d,a); a.copy(b); }
204  inline void incRow(int i,const VectorT& b) { VectorT a; getRowRef(i,a); a.inc(b); }
205  inline void incCol(int j,const VectorT& b) { VectorT a; getColRef(j,a); a.inc(b); }
206  inline void incDiag(int d,const VectorT& b) { VectorT a; getDiagRef(d,a); a.inc(b); }
207  inline void mulRow(int i,T c) { VectorT a; getRowRef(i,a); a.inplaceMul(c); }
208  inline void mulCol(int j,T c) { VectorT a; getColRef(j,a); a.inplaceMul(c); }
209  inline void mulDiag(int d,T c) { VectorT a; getDiagRef(d,a); a.inplaceMul(c); }
210  inline void maddRow(int i,const VectorT& b,T c) { VectorT a; getRowRef(i,a); a.mul(b,c); }
211  inline void maddCol(int j,const VectorT& b,T c) { VectorT a; getColRef(j,a); a.madd(b,c); }
212  inline void maddDiag(int d,const VectorT& b,T c) { VectorT a; getDiagRef(d,a); a.madd(b,c); }
213  inline T dotRow(int i,const VectorT& b) const { VectorT a; getRowRef(i,a); return a.dot(b); }
214  inline T dotCol(int j,const VectorT& b) const { VectorT a; getColRef(j,a); return a.dot(b); }
215 
216  inline void incRow(int i,const MyT& m,int im) { VectorT a; m.getRowRef(im,a); incRow(i,a); }
217  inline void incCol(int j,const MyT& m,int jm) { VectorT a; m.getColRef(jm,a); incCol(j,a); }
218  inline void maddRow(int i,const MyT& m,int im,T c) { VectorT a; m.getRowRef(im,a); maddRow(i,a,c); }
219  inline void maddCol(int j,const MyT& m,int jm,T c) { VectorT a; m.getColRef(jm,a); maddCol(j,a,c); }
220  inline T dotRow(int i,const MyT& m,int im) const { VectorT a; m.getRowRef(im,a); return dotRow(i,a); }
221  inline T dotCol(int j,const MyT& m,int jm) const { VectorT a; m.getColRef(jm,a); return dotCol(j,a); }
222 
223  void componentMul(const MyT& a,const MyT& b);
224  void componentDiv(const MyT& a,const MyT& b);
225  void componentMadd(const MyT& a,const MyT& b);
226  void inplaceComponentMul(const MyT& c);
227  void inplaceComponentDiv(const MyT& c);
228 
229 private:
230  //read only
231  T* vals;
232  int capacity;
233  bool allocated;
234 
235 public:
236  //alterable
237  int base,istride,m,jstride,n;
238 };
239 
241 template <class T>
242 inline bool IsFinite(const MatrixTemplate<T>& A)
243 {
244  for(int i=0;i<A.m;i++)
245  for(int j=0;j<A.n;j++)
246  if(!IsFinite(A(i,j))) return false;
247  return true;
248 }
249 
251 template <class T>
252 inline bool HasNaN(const MatrixTemplate<T>& A)
253 {
254  for(int i=0;i<A.m;i++)
255  for(int j=0;j<A.n;j++)
256  if(IsNaN(A(i,j))) return true;
257  return false;
258 }
259 
261 template <class T>
262 inline int HasInf(const MatrixTemplate<T>& A)
263 {
264  for(int i=0;i<A.m;i++)
265  for(int j=0;j<A.n;j++)
266  if(IsInf(A(i,j))) return IsInf(A(i,j));
267  return 0;
268 }
269 
270 
271 class Complex;
272 typedef class MatrixTemplate<float> fMatrix;
273 typedef class MatrixTemplate<double> dMatrix;
274 typedef class MatrixTemplate<Complex> cMatrix;
275 
276 template <class T>
277 std::ostream& operator << (std::ostream&, const MatrixTemplate<T>&);
278 template <class T>
279 std::istream& operator >> (std::istream&, MatrixTemplate<T>&);
280 
281 extern const char* MatrixError_IncompatibleDimensions;
282 extern const char* MatrixError_ArgIncompatibleDimensions;
283 extern const char* MatrixError_DestIncompatibleDimensions;
284 extern const char* MatrixError_SizeZero;
285 extern const char* MatrixError_NotSquare;
286 extern const char* MatrixError_NotSymmetric;
287 extern const char* MatrixError_InvalidRow;
288 extern const char* MatrixError_InvalidCol;
289 
290 template <class T>
291 inline const T& MatrixTemplate<T>::operator() (int i,int j) const
292 {
293 #ifdef _DEBUG
294  if(!isValidRow(i))
295  FatalError(MatrixError_InvalidRow);
296  if(!isValidCol(j))
297  FatalError(MatrixError_InvalidCol);
298 #endif
299  return vals[base+i*istride+j*jstride];
300 }
301 
302 template <class T>
303 inline T& MatrixTemplate<T>::operator() (int i,int j)
304 {
305 #ifdef _DEBUG
306  if(!isValidRow(i))
307  FatalError(MatrixError_InvalidRow);
308  if(!isValidCol(j))
309  FatalError(MatrixError_InvalidCol);
310 #endif
311  return vals[base+i*istride+j*jstride];
312 }
313 
314 } //namespace Math
315 
316 namespace std
317 {
318  template<class T> inline void swap(Math::MatrixTemplate<T>& a, Math::MatrixTemplate<T>& b)
319  {
320  a.swap(b);
321  }
322 } //namespace std
323 
324 
325 #endif
int IsInf(double x)
Returns +1 if x is +inf, -1 if x is -inf, and 0 otherwise.
Definition: infnan.cpp:92
Definition: rayprimitives.h:132
int HasInf(const MatrixTemplate< T > &A)
returns nonzero if any element of A is infinite
Definition: MatrixTemplate.h:262
Complex number class (x + iy).
Definition: complex.h:17
bool HasNaN(const MatrixTemplate< T > &A)
returns true if any element of A is NaN
Definition: MatrixTemplate.h:252
void copy(const T &a, T *out, int n)
Definition: arrayutils.h:34
int IsNaN(double x)
Returns nonzero if x is not-a-number (NaN)
Definition: infnan.cpp:61
A matrix over the field T.
Definition: function.h:10
An iterator through MatrixTemplate elements.
Definition: MatrixTemplate.h:17
Contains all definitions in the Math package.
Definition: WorkspaceBound.h:12
A vector over the field T.
Definition: function.h:9
A cross-platform class for reading/writing binary data.
Definition: File.h:47
int IsFinite(double x)
Returns nonzero unless x is infinite or a NaN.
Definition: infnan.cpp:75