24 #ifndef LBCRYPTO_MATH_MATRIXSTRASSEN_H 25 #define LBCRYPTO_MATH_MATRIXSTRASSEN_H 31 #include "math/matrix.h" 35 template <
class Element>
38 typedef vector<vector<Element>> data_t;
39 typedef vector<Element> lineardata_t;
40 typedef typename vector<Element>::iterator it_lineardata_t;
41 typedef std::function<Element(void)> alloc_func;
51 : data(), rows(rows), cols(cols), allocZero(allocZero) {
53 for (
auto row = data.begin(); row != data.end(); ++row) {
54 for (
size_t col = 0; col < cols; ++col) {
55 row->push_back(allocZero());
81 : data(), rows(0), cols(0), allocZero(allocZero) {}
83 void SetSize(
size_t rows,
size_t cols) {
84 if (this->rows != 0 || this->cols != 0) {
86 "You cannot SetSize on a non-empty matrix");
93 for (
auto row = data.begin(); row != data.end(); ++row) {
94 for (
size_t col = 0; col < cols; ++col) {
95 row->push_back(allocZero());
106 : data(), rows(other.rows), cols(other.cols), allocZero(other.allocZero) {
107 deepCopyData(other.data);
154 inline double Norm()
const;
175 #pragma omp parallel for 176 for (int32_t col = 0; col < result.cols; ++col) {
177 for (int32_t row = 0; row < result.rows; ++row) {
178 *result.data[row][col] = *result.data[row][col] * other;
201 if (rows != other.rows || cols != other.cols) {
205 for (
size_t i = 0; i < rows; ++i) {
206 for (
size_t j = 0; j < cols; ++j) {
207 if (data[i][j] != other.data[i][j]) {
232 return !
Equal(other);
240 const data_t&
GetData()
const {
return data; }
280 if (rows != other.rows || cols != other.cols) {
282 "Addition operands have incompatible dimensions");
285 #pragma omp parallel for 286 for (int32_t j = 0; j < cols; ++j) {
287 for (int32_t i = 0; i < rows; ++i) {
288 *result.data[i][j] += *other.data[i][j];
303 return this->
Add(other);
323 if (rows != other.rows || cols != other.cols) {
325 "Subtraction operands have incompatible dimensions");
328 #pragma omp parallel for 329 for (int32_t j = 0; j < cols; ++j) {
330 for (int32_t i = 0; i < rows; ++i) {
331 *result.data[i][j] = *data[i][j] - *other.data[i][j];
346 return this->
Sub(other);
405 inline Element&
operator()(
size_t row,
size_t col) {
return data[row][col]; }
414 inline Element
const&
operator()(
size_t row,
size_t col)
const {
415 return data[row][col];
427 for (
auto elem = this->
GetData()[row].begin();
428 elem != this->
GetData()[row].end(); ++elem) {
429 result(0, i) = **elem;
449 int nrec = 0,
int pad = -1)
const;
466 struct MatDescriptor {
475 const int DESC_SIZE = 7;
476 const int rank = 0, base = 0;
480 mutable int rowpad = 0;
482 mutable int colpad = 0;
483 alloc_func allocZero;
484 mutable char* pattern =
nullptr;
485 mutable int numAdd = 0;
486 mutable int numMult = 0;
487 mutable int numSub = 0;
488 mutable MatDescriptor desc;
489 mutable Element zeroUniquePtr = allocZero();
490 mutable int NUM_THREADS = 1;
492 void multiplyInternalCAPS(it_lineardata_t A, it_lineardata_t B,
493 it_lineardata_t C, MatDescriptor desc,
494 it_lineardata_t work)
const;
495 void strassenDFSCAPS(it_lineardata_t A, it_lineardata_t B, it_lineardata_t C,
497 it_lineardata_t workPassThrough)
const;
498 void block_multiplyCAPS(it_lineardata_t A, it_lineardata_t B,
499 it_lineardata_t C, MatDescriptor d,
500 it_lineardata_t workPassThrough)
const;
501 void LinearizeDataCAPS(lineardata_t* lineardataPtr)
const;
502 void UnlinearizeDataCAPS(lineardata_t* lineardataPtr)
const;
504 void verifyDescriptor(MatDescriptor desc);
505 long long numEntriesPerProc(MatDescriptor desc)
const;
507 void deepCopyData(data_t
const& src);
508 void getData(
const data_t& Adata,
const data_t& Bdata,
const data_t& Cdata,
509 int row,
int inner,
int col)
const;
511 void smartSubtractionCAPS(it_lineardata_t result, it_lineardata_t A,
512 it_lineardata_t B)
const;
513 void smartAdditionCAPS(it_lineardata_t result, it_lineardata_t A,
514 it_lineardata_t B)
const;
515 void addMatricesCAPS(
int numEntries, it_lineardata_t C, it_lineardata_t A,
516 it_lineardata_t B)
const;
517 void addSubMatricesCAPS(
int numEntries, it_lineardata_t T1,
518 it_lineardata_t S11, it_lineardata_t S12,
519 it_lineardata_t T2, it_lineardata_t S21,
520 it_lineardata_t S22)
const;
521 void subMatricesCAPS(
int numEntries, it_lineardata_t C, it_lineardata_t A,
522 it_lineardata_t B)
const;
523 void tripleAddMatricesCAPS(
int numEntries, it_lineardata_t T1,
524 it_lineardata_t S11, it_lineardata_t S12,
525 it_lineardata_t T2, it_lineardata_t S21,
526 it_lineardata_t S22, it_lineardata_t T3,
527 it_lineardata_t S31, it_lineardata_t S32)
const;
528 void tripleSubMatricesCAPS(
int numEntries, it_lineardata_t T1,
529 it_lineardata_t S11, it_lineardata_t S12,
530 it_lineardata_t T2, it_lineardata_t S21,
531 it_lineardata_t S22, it_lineardata_t T3,
532 it_lineardata_t S31, it_lineardata_t S32)
const;
534 void distributeFrom1ProcCAPS(MatDescriptor desc, it_lineardata_t O,
535 it_lineardata_t I)
const;
536 void collectTo1ProcCAPS(MatDescriptor desc, it_lineardata_t O,
537 it_lineardata_t I)
const;
538 void sendBlockCAPS(
int rank,
int target, it_lineardata_t O,
int bs,
539 int source, it_lineardata_t I,
int ldi)
const;
540 void receiveBlockCAPS(
int rank,
int target, it_lineardata_t O,
int bs,
541 int source, it_lineardata_t I,
int ldo)
const;
542 void distributeFrom1ProcRecCAPS(MatDescriptor desc, it_lineardata_t O,
543 it_lineardata_t I,
int ldi)
const;
544 void collectTo1ProcRecCAPS(MatDescriptor desc, it_lineardata_t O,
545 it_lineardata_t I,
int ldo)
const;
555 template <
class Element>
588 template <
class Element>
589 inline std::ostream& operator<<(std::ostream& os,
638 const shared_ptr<ILParams> params);
651 const shared_ptr<ILParams> params);
653 #endif // LBCRYPTO_MATH_MATRIXSTRASSEN_H const data_t & GetData() const
Definition: matrixstrassen.h:240
MatrixStrassen< Element > & Ones()
Definition: matrixstrassen.cpp:53
MatrixStrassen(alloc_func allocZero)
Definition: matrixstrassen.h:80
MatrixStrassen(const MatrixStrassen< Element > &other)
Definition: matrixstrassen.h:105
void Determinant(Element *result) const
Definition: matrixstrassen.cpp:180
MatrixStrassen< Element > Sub(MatrixStrassen< Element > const &other) const
Definition: matrixstrassen.h:321
MatrixStrassen< Element > operator+(MatrixStrassen< Element > const &other) const
Definition: matrixstrassen.h:301
MatrixStrassen(alloc_func allocZero, size_t rows, size_t cols)
Definition: matrixstrassen.h:50
MatrixStrassen< Element > & VStack(MatrixStrassen< Element > const &other)
Definition: matrixstrassen.cpp:279
MatrixStrassen< Element > & Identity()
Definition: matrixstrassen.cpp:73
MatrixStrassen< Element > operator-(MatrixStrassen< Element > const &other) const
Definition: matrixstrassen.h:344
alloc_func GetAllocator() const
Definition: matrixstrassen.h:261
size_t GetCols() const
Definition: matrixstrassen.h:254
MatrixStrassen< Element > & operator+=(MatrixStrassen< Element > const &other)
Definition: matrixstrassen.cpp:128
MatrixStrassen< Element > & operator-=(MatrixStrassen< Element > const &other)
Definition: matrixstrassen.cpp:144
Definition: exception.h:113
MatrixStrassen< Poly > SplitInt32IntoPolyElements(MatrixStrassen< int32_t > const &other, size_t n, const shared_ptr< ILParams > params)
Definition: matrixstrassen.cpp:544
MatrixStrassen< Element > & Fill(const Element &val)
Definition: matrixstrassen.cpp:63
Definition: matrixstrassen.h:36
Element & operator()(size_t row, size_t col)
Definition: matrixstrassen.h:405
MatrixStrassen< Element > Add(MatrixStrassen< Element > const &other) const
Definition: matrixstrassen.h:278
MatrixStrassen< Element > Transpose() const
Definition: matrixstrassen.cpp:161
size_t GetRows() const
Definition: matrixstrassen.h:247
MatrixStrassen< Poly > SplitInt32AltIntoPolyElements(MatrixStrassen< int32_t > const &other, size_t n, const shared_ptr< ILParams > params)
Definition: matrixstrassen.cpp:576
Main class for big integers represented as an array of native (primitive) unsigned integers...
Definition: ubintfxd.h:219
bool operator!=(MatrixStrassen< Element > const &other) const
Definition: matrixstrassen.h:231
void SwitchFormat()
Definition: matrixstrassen.cpp:315
Matrix< typename Element::Vector > RotateVecResult(Matrix< Element > const &inMat)
Definition: matrix-lattice-impl.cpp:72
MatrixStrassen< Element > & operator=(const MatrixStrassen< Element > &other)
Definition: matrixstrassen.cpp:44
bool operator==(MatrixStrassen< Element > const &other) const
Definition: matrixstrassen.h:221
void SetFormat(Format format)
Definition: matrixstrassen.cpp:119
Matrix< typename Element::Integer > Rotate(Matrix< Element > const &inMat)
Definition: matrix-lattice-impl.cpp:39
bool Equal(MatrixStrassen< Element > const &other) const
Definition: matrixstrassen.h:200
MatrixStrassen< Element > ScalarMult(Element const &other) const
Definition: matrixstrassen.h:173
MatrixStrassen< Element > operator*(Element const &other) const
Definition: matrixstrassen.h:190
double Norm() const
Definition: matrixstrassen.cpp:102
MatrixStrassen< Element > Mult(const MatrixStrassen< Element > &other, int nrec=0, int pad=-1) const
Definition: matrixstrassen.cpp:609
MatrixStrassen< Element > operator*(MatrixStrassen< Element > const &other) const
Definition: matrixstrassen.h:162
Element const & operator()(size_t row, size_t col) const
Definition: matrixstrassen.h:414
MatrixStrassen< Element > & HStack(MatrixStrassen< Element > const &other)
Definition: matrixstrassen.cpp:297
MatrixStrassen< Element > CofactorMatrixStrassen() const
Definition: matrixstrassen.cpp:233
MatrixStrassen< Element > GadgetVector(int32_t base=2) const
Definition: matrixstrassen.cpp:87
MatrixStrassen< Element > ExtractRow(size_t row) const
Definition: matrixstrassen.h:424
Matrix< double > Cholesky(const Matrix< int32_t > &input)
Definition: binfhecontext.h:36
Definition: exception.h:126
Matrix< int32_t > ConvertToInt32(const Matrix< BigInteger > &input, const BigInteger &modulus)
Definition: matrix-impl.cpp:197