Jochen Wilhelmy
2012-Apr-12 17:14 UTC
[LLVMdev] detection of constant diagonal matrix * vector
Hi! currently instcombine does not detect constant diagonal matrix times vector, for example a.xx * [2 0] + a.yy * [0 3] can be optimized to a * [2 3] I have implemented this for float. I know that this assumes x * 0 = 0 which is not ieee compliant but i post it here in case it is interesting for someone. on my wish list there is still an option for target independent optimizations to have x * 0 = 0. -Jochen static void getIntVector(Value* value, SmallVector<int, 8>& values) { if (llvm::ConstantVector* constantVector = llvm::dyn_cast<llvm::ConstantVector>(value)) { // get components llvm::SmallVector<llvm::Constant*, 8> elements; constantVector->getVectorElements(elements); int numElements = int(elements.size()); for (int i = 0; i < numElements; ++i) { if (llvm::ConstantInt* element = llvm::dyn_cast<llvm::ConstantInt>(elements[i])) values[i] = int(element->getZExtValue()); } } } at the end of InstCombiner::visitFAdd: // check for constant diagonal matrix * vector: a.xx * [2 0] + a.yy * [0 3] --> a * [2 3] BinaryOperator* leftMul = dyn_cast<BinaryOperator>(LHS); BinaryOperator* rightMul = dyn_cast<BinaryOperator>(RHS); if (leftMul != NULL && rightMul != NULL && leftMul->getOpcode() == Instruction::FMul && rightMul->getOpcode() == Instruction::FMul) { ShuffleVectorInst* leftShuffle = dyn_cast<ShuffleVectorInst>(leftMul->getOperand(0)); ShuffleVectorInst* rightShuffle = dyn_cast<ShuffleVectorInst>(rightMul->getOperand(0)); // get multiplication constant vectors (e.g. [0 1]) ConstantVector* leftConstVector = llvm::dyn_cast<ConstantVector>(leftMul->getOperand(1)); ConstantVector* rightConstVector = llvm::dyn_cast<ConstantVector>(rightMul->getOperand(1)); if (leftShuffle != NULL && rightShuffle != NULL && leftConstVector != NULL && rightConstVector != NULL) { Value* value = leftShuffle->getOperand(0); if (value == rightShuffle->getOperand(0)) { int numElements = cast<VectorType>(I.getType())->getNumElements(); // get shuffle masks (e.g. .xx) SmallVector<int, 8> leftMask(numElements); SmallVector<int, 8> rightMask(numElements); getIntVector(leftShuffle->getOperand(2), leftMask); getIntVector(rightShuffle->getOperand(2), rightMask); SmallVector<Constant*, 8> leftConsts; SmallVector<Constant*, 8> rightConsts; leftConstVector->getVectorElements(leftConsts); rightConstVector->getVectorElements(rightConsts); if (leftConsts.size() == numElements && rightConsts.size() == numElements) { SmallVector<Constant*, 8> newShuffleMask(numElements); SmallVector<Constant*, 8> newConst(numElements); int i; bool noShuffle = true; for (i = 0; i < numElements; ++i) { // get shuffle indices int leftIndex = leftMask[i]; int rightIndex = rightMask[i]; // check if indices access the first vector if (leftIndex >= numElements && rightIndex >= numElements) break; // get values from constant vectors ConstantFP* leftConst = dyn_cast<ConstantFP>(leftConsts[i]); ConstantFP* rightConst = dyn_cast<ConstantFP>(rightConsts[i]); // check if valid if (leftConst == NULL || rightConst == NULL) break; // check if at least one is zero if (!leftConst->isZero() && !rightConst->isZero()) break; // assign dependent on constant int index = leftIndex; ConstantFP* constant = leftConst; if (!rightConst->isZero()) { index = rightIndex; constant = rightConst; } newShuffleMask[i] = Builder->getInt32(index); newConst[i] = constant; noShuffle &= index == i; } // check if we made it through if (i == numElements) { Value* newShuffle = noShuffle ? value : Builder->CreateShuffleVector( value, leftShuffle->getOperand(1), ConstantVector::get(newShuffleMask), "shuffle"); return BinaryOperator::CreateFMul(newShuffle, ConstantVector::get(newConst), "mul"); } } } } }