Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 : Copyright (c) 2014-2017 The plumed team 3 : (see the PEOPLE file at the root of the distribution for a list of names) 4 : 5 : See http://www.plumed.org for more information. 6 : 7 : This file is part of plumed, version 2. 8 : 9 : plumed is free software: you can redistribute it and/or modify 10 : it under the terms of the GNU Lesser General Public License as published by 11 : the Free Software Foundation, either version 3 of the License, or 12 : (at your option) any later version. 13 : 14 : plumed is distributed in the hope that it will be useful, 15 : but WITHOUT ANY WARRANTY; without even the implied warranty of 16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 : GNU Lesser General Public License for more details. 18 : 19 : You should have received a copy of the GNU Lesser General Public License 20 : along with plumed. If not, see <http://www.gnu.org/licenses/>. 21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */ 22 : #include "MatrixOperationBase.h" 23 : #include "core/ActionRegister.h" 24 : 25 : //+PLUMEDOC MCOLVAR TRANSPOSE 26 : /* 27 : Calculate the transpose of a matrix 28 : 29 : \par Examples 30 : 31 : */ 32 : //+ENDPLUMEDOC 33 : 34 : namespace PLMD { 35 : namespace matrixtools { 36 : 37 : class TransposeMatrix : public MatrixOperationBase { 38 : public: 39 : static void registerKeywords( Keywords& keys ); 40 : /// Constructor 41 : explicit TransposeMatrix(const ActionOptions&); 42 : /// 43 256 : unsigned getNumberOfDerivatives() override { 44 256 : return 0; 45 : } 46 : /// 47 : void prepare() override ; 48 : /// 49 : void calculate() override ; 50 : /// 51 : void apply() override ; 52 : /// 53 : double getForceOnMatrixElement( const unsigned& jrow, const unsigned& krow ) const override; 54 : }; 55 : 56 : PLUMED_REGISTER_ACTION(TransposeMatrix,"TRANSPOSE") 57 : 58 279 : void TransposeMatrix::registerKeywords( Keywords& keys ) { 59 279 : MatrixOperationBase::registerKeywords( keys ); 60 279 : keys.setValueDescription("the transpose of the input matrix"); 61 279 : } 62 : 63 162 : TransposeMatrix::TransposeMatrix(const ActionOptions& ao): 64 : Action(ao), 65 162 : MatrixOperationBase(ao) { 66 162 : if( getPntrToArgument(0)->isSymmetric() ) { 67 0 : error("input matrix is symmetric. Transposing will achieve nothing!"); 68 : } 69 : std::vector<unsigned> shape; 70 162 : if( getPntrToArgument(0)->getRank()==0 ) { 71 0 : error("transposing a scalar?"); 72 162 : } else if( getPntrToArgument(0)->getRank()==1 ) { 73 17 : shape.resize(2); 74 17 : shape[0]=1; 75 17 : shape[1]=getPntrToArgument(0)->getShape()[0]; 76 145 : } else if( getPntrToArgument(0)->getShape()[0]==1 ) { 77 61 : shape.resize(1); 78 61 : shape[0] = getPntrToArgument(0)->getShape()[1]; 79 : } else { 80 84 : shape.resize(2); 81 84 : shape[0]=getPntrToArgument(0)->getShape()[1]; 82 84 : shape[1]=getPntrToArgument(0)->getShape()[0]; 83 : } 84 162 : addValue( shape ); 85 162 : setNotPeriodic(); 86 162 : getPntrToComponent(0)->buildDataStore(); 87 162 : if( shape.size()==2 ) { 88 101 : getPntrToComponent(0)->reshapeMatrixStore( shape[1] ); 89 : } 90 162 : } 91 : 92 4821 : void TransposeMatrix::prepare() { 93 4821 : Value* myval = getPntrToComponent(0); 94 : Value* myarg = getPntrToArgument(0); 95 4821 : if( myarg->getRank()==1 ) { 96 586 : if( myval->getShape()[0]!=1 || myval->getShape()[1]!=myarg->getShape()[0] ) { 97 6 : std::vector<unsigned> shape(2); 98 6 : shape[0] = 1; 99 6 : shape[1] = myarg->getShape()[0]; 100 6 : myval->setShape( shape ); 101 6 : myval->reshapeMatrixStore( shape[1] ); 102 : } 103 4235 : } else if( myarg->getShape()[0]==1 ) { 104 2392 : if( myval->getShape()[0]!=myarg->getShape()[1] ) { 105 6 : std::vector<unsigned> shape(1); 106 6 : shape[0] = myarg->getShape()[1]; 107 6 : myval->setShape( shape ); 108 : } 109 1843 : } else if( myarg->getShape()[0]!=myval->getShape()[1] || myarg->getShape()[1]!=myval->getShape()[0] ) { 110 19 : std::vector<unsigned> shape(2); 111 19 : shape[0] = myarg->getShape()[1]; 112 19 : shape[1] = myarg->getShape()[0]; 113 19 : myval->setShape( shape ); 114 19 : myval->reshapeMatrixStore( shape[1] ); 115 : } 116 4821 : } 117 : 118 4225 : void TransposeMatrix::calculate() { 119 : // Retrieve the non-zero pairs 120 : Value* myarg=getPntrToArgument(0); 121 4225 : Value* myval=getPntrToComponent(0); 122 4225 : if( myarg->getRank()<=1 || myval->getRank()==1 ) { 123 2389 : if( myarg->getRank()<=1 && myval->getShape()[1]!=myarg->getShape()[0] ) { 124 0 : std::vector<unsigned> shape( 2 ); 125 0 : shape[0] = 1; 126 0 : shape[1] = myarg->getShape()[0]; 127 0 : myval->setShape( shape ); 128 0 : myval->reshapeMatrixStore( shape[1] ); 129 2389 : } else if( myval->getRank()==1 && myval->getShape()[0]!=myarg->getShape()[1] ) { 130 0 : std::vector<unsigned> shape( 1 ); 131 0 : shape[0] = myarg->getShape()[1]; 132 0 : myval->setShape( shape ); 133 : } 134 2389 : unsigned nv=myarg->getNumberOfValues(); 135 49015 : for(unsigned i=0; i<nv; ++i) { 136 46626 : myval->set( i, myarg->get(i) ); 137 : } 138 : } else { 139 1836 : if( myarg->getShape()[0]!=myval->getShape()[1] || myarg->getShape()[1]!=myval->getShape()[0] ) { 140 0 : std::vector<unsigned> shape( 2 ); 141 0 : shape[0] = myarg->getShape()[1]; 142 0 : shape[1] = myarg->getShape()[0]; 143 0 : myval->setShape( shape ); 144 0 : myval->reshapeMatrixStore( shape[1] ); 145 : } 146 : std::vector<double> vals; 147 : std::vector<std::pair<unsigned,unsigned> > pairs; 148 1836 : std::vector<unsigned> shape( myval->getShape() ); 149 1836 : unsigned nedge=0; 150 1836 : myarg->retrieveEdgeList( nedge, pairs, vals ); 151 2758860 : for(unsigned i=0; i<nedge; ++i) { 152 2757024 : myval->set( pairs[i].second*shape[1] + pairs[i].first, vals[i] ); 153 : } 154 : } 155 4225 : } 156 : 157 4144 : void TransposeMatrix::apply() { 158 4144 : if( doNotCalculateDerivatives() ) { 159 : return; 160 : } 161 : 162 : // Apply force on the matrix 163 1930 : if( getPntrToComponent(0)->forcesWereAdded() ) { 164 : Value* myarg=getPntrToArgument(0); 165 1930 : Value* myval=getPntrToComponent(0); 166 1930 : if( myarg->getRank()<=1 || myval->getRank()==1 ) { 167 588 : unsigned nv=myarg->getNumberOfValues(); 168 2408 : for(unsigned i=0; i<nv; ++i) { 169 1820 : myarg->addForce( i, myval->getForce(i) ); 170 : } 171 : } else { 172 1342 : MatrixOperationBase::apply(); 173 : } 174 : } 175 : } 176 : 177 4124572 : double TransposeMatrix::getForceOnMatrixElement( const unsigned& jrow, const unsigned& kcol ) const { 178 4124572 : return getConstPntrToComponent(0)->getForce(kcol*getConstPntrToComponent(0)->getShape()[1]+jrow); 179 : } 180 : 181 : 182 : } 183 : }