diff options
Diffstat (limited to 'gfx/angle/src/compiler/translator/RemoveDynamicIndexing.cpp')
-rwxr-xr-x | gfx/angle/src/compiler/translator/RemoveDynamicIndexing.cpp | 513 |
1 files changed, 513 insertions, 0 deletions
diff --git a/gfx/angle/src/compiler/translator/RemoveDynamicIndexing.cpp b/gfx/angle/src/compiler/translator/RemoveDynamicIndexing.cpp new file mode 100755 index 000000000..31914dcf3 --- /dev/null +++ b/gfx/angle/src/compiler/translator/RemoveDynamicIndexing.cpp @@ -0,0 +1,513 @@ +// +// Copyright (c) 2002-2015 The ANGLE Project Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// RemoveDynamicIndexing is an AST traverser to remove dynamic indexing of vectors and matrices, +// replacing them with calls to functions that choose which component to return or write. +// + +#include "compiler/translator/RemoveDynamicIndexing.h" + +#include "compiler/translator/InfoSink.h" +#include "compiler/translator/IntermNode.h" +#include "compiler/translator/IntermNodePatternMatcher.h" +#include "compiler/translator/SymbolTable.h" + +namespace sh +{ + +namespace +{ + +TName GetIndexFunctionName(const TType &type, bool write) +{ + TInfoSinkBase nameSink; + nameSink << "dyn_index_"; + if (write) + { + nameSink << "write_"; + } + if (type.isMatrix()) + { + nameSink << "mat" << type.getCols() << "x" << type.getRows(); + } + else + { + switch (type.getBasicType()) + { + case EbtInt: + nameSink << "ivec"; + break; + case EbtBool: + nameSink << "bvec"; + break; + case EbtUInt: + nameSink << "uvec"; + break; + case EbtFloat: + nameSink << "vec"; + break; + default: + UNREACHABLE(); + } + nameSink << type.getNominalSize(); + } + TString nameString = TFunction::mangleName(nameSink.c_str()); + TName name(nameString); + name.setInternal(true); + return name; +} + +TIntermSymbol *CreateBaseSymbol(const TType &type, TQualifier qualifier) +{ + TIntermSymbol *symbol = new TIntermSymbol(0, "base", type); + symbol->setInternal(true); + symbol->getTypePointer()->setQualifier(qualifier); + return symbol; +} + +TIntermSymbol *CreateIndexSymbol() +{ + TIntermSymbol *symbol = new TIntermSymbol(0, "index", TType(EbtInt, EbpHigh)); + symbol->setInternal(true); + symbol->getTypePointer()->setQualifier(EvqIn); + return symbol; +} + +TIntermSymbol *CreateValueSymbol(const TType &type) +{ + TIntermSymbol *symbol = new TIntermSymbol(0, "value", type); + symbol->setInternal(true); + symbol->getTypePointer()->setQualifier(EvqIn); + return symbol; +} + +TIntermConstantUnion *CreateIntConstantNode(int i) +{ + TConstantUnion *constant = new TConstantUnion(); + constant->setIConst(i); + return new TIntermConstantUnion(constant, TType(EbtInt, EbpHigh)); +} + +TIntermBinary *CreateIndexDirectBaseSymbolNode(const TType &indexedType, + const TType &fieldType, + const int index, + TQualifier baseQualifier) +{ + TIntermSymbol *baseSymbol = CreateBaseSymbol(indexedType, baseQualifier); + TIntermBinary *indexNode = + new TIntermBinary(EOpIndexDirect, baseSymbol, TIntermTyped::CreateIndexNode(index)); + return indexNode; +} + +TIntermBinary *CreateAssignValueSymbolNode(TIntermTyped *targetNode, const TType &assignedValueType) +{ + return new TIntermBinary(EOpAssign, targetNode, CreateValueSymbol(assignedValueType)); +} + +TIntermTyped *EnsureSignedInt(TIntermTyped *node) +{ + if (node->getBasicType() == EbtInt) + return node; + + TIntermAggregate *convertedNode = new TIntermAggregate(EOpConstructInt); + convertedNode->setType(TType(EbtInt)); + convertedNode->getSequence()->push_back(node); + convertedNode->setPrecisionFromChildren(); + return convertedNode; +} + +TType GetFieldType(const TType &indexedType) +{ + if (indexedType.isMatrix()) + { + TType fieldType = TType(indexedType.getBasicType(), indexedType.getPrecision()); + fieldType.setPrimarySize(static_cast<unsigned char>(indexedType.getRows())); + return fieldType; + } + else + { + return TType(indexedType.getBasicType(), indexedType.getPrecision()); + } +} + +// Generate a read or write function for one field in a vector/matrix. +// Out-of-range indices are clamped. This is consistent with how ANGLE handles out-of-range +// indices in other places. +// Note that indices can be either int or uint. We create only int versions of the functions, +// and convert uint indices to int at the call site. +// read function example: +// float dyn_index_vec2(in vec2 base, in int index) +// { +// switch(index) +// { +// case (0): +// return base[0]; +// case (1): +// return base[1]; +// default: +// break; +// } +// if (index < 0) +// return base[0]; +// return base[1]; +// } +// write function example: +// void dyn_index_write_vec2(inout vec2 base, in int index, in float value) +// { +// switch(index) +// { +// case (0): +// base[0] = value; +// return; +// case (1): +// base[1] = value; +// return; +// default: +// break; +// } +// if (index < 0) +// { +// base[0] = value; +// return; +// } +// base[1] = value; +// } +// Note that else is not used in above functions to avoid the RewriteElseBlocks transformation. +TIntermFunctionDefinition *GetIndexFunctionDefinition(TType type, bool write) +{ + ASSERT(!type.isArray()); + // Conservatively use highp here, even if the indexed type is not highp. That way the code can't + // end up using mediump version of an indexing function for a highp value, if both mediump and + // highp values are being indexed in the shader. For HLSL precision doesn't matter, but in + // principle this code could be used with multiple backends. + type.setPrecision(EbpHigh); + + TType fieldType = GetFieldType(type); + int numCases = 0; + if (type.isMatrix()) + { + numCases = type.getCols(); + } + else + { + numCases = type.getNominalSize(); + } + + TIntermAggregate *paramsNode = new TIntermAggregate(EOpParameters); + TQualifier baseQualifier = EvqInOut; + if (!write) + baseQualifier = EvqIn; + TIntermSymbol *baseParam = CreateBaseSymbol(type, baseQualifier); + paramsNode->getSequence()->push_back(baseParam); + TIntermSymbol *indexParam = CreateIndexSymbol(); + paramsNode->getSequence()->push_back(indexParam); + if (write) + { + TIntermSymbol *valueParam = CreateValueSymbol(fieldType); + paramsNode->getSequence()->push_back(valueParam); + } + + TIntermBlock *statementList = new TIntermBlock(); + for (int i = 0; i < numCases; ++i) + { + TIntermCase *caseNode = new TIntermCase(CreateIntConstantNode(i)); + statementList->getSequence()->push_back(caseNode); + + TIntermBinary *indexNode = + CreateIndexDirectBaseSymbolNode(type, fieldType, i, baseQualifier); + if (write) + { + TIntermBinary *assignNode = CreateAssignValueSymbolNode(indexNode, fieldType); + statementList->getSequence()->push_back(assignNode); + TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr); + statementList->getSequence()->push_back(returnNode); + } + else + { + TIntermBranch *returnNode = new TIntermBranch(EOpReturn, indexNode); + statementList->getSequence()->push_back(returnNode); + } + } + + // Default case + TIntermCase *defaultNode = new TIntermCase(nullptr); + statementList->getSequence()->push_back(defaultNode); + TIntermBranch *breakNode = new TIntermBranch(EOpBreak, nullptr); + statementList->getSequence()->push_back(breakNode); + + TIntermSwitch *switchNode = new TIntermSwitch(CreateIndexSymbol(), statementList); + + TIntermBlock *bodyNode = new TIntermBlock(); + bodyNode->getSequence()->push_back(switchNode); + + TIntermBinary *cond = + new TIntermBinary(EOpLessThan, CreateIndexSymbol(), CreateIntConstantNode(0)); + cond->setType(TType(EbtBool, EbpUndefined)); + + // Two blocks: one accesses (either reads or writes) the first element and returns, + // the other accesses the last element. + TIntermBlock *useFirstBlock = new TIntermBlock(); + TIntermBlock *useLastBlock = new TIntermBlock(); + TIntermBinary *indexFirstNode = + CreateIndexDirectBaseSymbolNode(type, fieldType, 0, baseQualifier); + TIntermBinary *indexLastNode = + CreateIndexDirectBaseSymbolNode(type, fieldType, numCases - 1, baseQualifier); + if (write) + { + TIntermBinary *assignFirstNode = CreateAssignValueSymbolNode(indexFirstNode, fieldType); + useFirstBlock->getSequence()->push_back(assignFirstNode); + TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr); + useFirstBlock->getSequence()->push_back(returnNode); + + TIntermBinary *assignLastNode = CreateAssignValueSymbolNode(indexLastNode, fieldType); + useLastBlock->getSequence()->push_back(assignLastNode); + } + else + { + TIntermBranch *returnFirstNode = new TIntermBranch(EOpReturn, indexFirstNode); + useFirstBlock->getSequence()->push_back(returnFirstNode); + + TIntermBranch *returnLastNode = new TIntermBranch(EOpReturn, indexLastNode); + useLastBlock->getSequence()->push_back(returnLastNode); + } + TIntermIfElse *ifNode = new TIntermIfElse(cond, useFirstBlock, nullptr); + bodyNode->getSequence()->push_back(ifNode); + bodyNode->getSequence()->push_back(useLastBlock); + + TIntermFunctionDefinition *indexingFunction = nullptr; + if (write) + { + indexingFunction = new TIntermFunctionDefinition(TType(EbtVoid), paramsNode, bodyNode); + } + else + { + indexingFunction = new TIntermFunctionDefinition(fieldType, paramsNode, bodyNode); + } + indexingFunction->getFunctionSymbolInfo()->setNameObj(GetIndexFunctionName(type, write)); + return indexingFunction; +} + +class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser +{ + public: + RemoveDynamicIndexingTraverser(const TSymbolTable &symbolTable, int shaderVersion); + + bool visitBinary(Visit visit, TIntermBinary *node) override; + + void insertHelperDefinitions(TIntermNode *root); + + void nextIteration(); + + bool usedTreeInsertion() const { return mUsedTreeInsertion; } + + protected: + // Sets of types that are indexed. Note that these can not store multiple variants + // of the same type with different precisions - only one precision gets stored. + std::set<TType> mIndexedVecAndMatrixTypes; + std::set<TType> mWrittenVecAndMatrixTypes; + + bool mUsedTreeInsertion; + + // When true, the traverser will remove side effects from any indexing expression. + // This is done so that in code like + // V[j++][i]++. + // where V is an array of vectors, j++ will only be evaluated once. + bool mRemoveIndexSideEffectsInSubtree; +}; + +RemoveDynamicIndexingTraverser::RemoveDynamicIndexingTraverser(const TSymbolTable &symbolTable, + int shaderVersion) + : TLValueTrackingTraverser(true, false, false, symbolTable, shaderVersion), + mUsedTreeInsertion(false), + mRemoveIndexSideEffectsInSubtree(false) +{ +} + +void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root) +{ + TIntermBlock *rootBlock = root->getAsBlock(); + ASSERT(rootBlock != nullptr); + TIntermSequence insertions; + for (TType type : mIndexedVecAndMatrixTypes) + { + insertions.push_back(GetIndexFunctionDefinition(type, false)); + } + for (TType type : mWrittenVecAndMatrixTypes) + { + insertions.push_back(GetIndexFunctionDefinition(type, true)); + } + mInsertions.push_back(NodeInsertMultipleEntry(rootBlock, 0, insertions, TIntermSequence())); +} + +// Create a call to dyn_index_*() based on an indirect indexing op node +TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node, + TIntermTyped *indexedNode, + TIntermTyped *index) +{ + ASSERT(node->getOp() == EOpIndexIndirect); + TIntermAggregate *indexingCall = new TIntermAggregate(EOpFunctionCall); + indexingCall->setLine(node->getLine()); + indexingCall->setUserDefined(); + indexingCall->getFunctionSymbolInfo()->setNameObj( + GetIndexFunctionName(indexedNode->getType(), false)); + indexingCall->getSequence()->push_back(indexedNode); + indexingCall->getSequence()->push_back(index); + + TType fieldType = GetFieldType(indexedNode->getType()); + indexingCall->setType(fieldType); + return indexingCall; +} + +TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node, + TIntermTyped *index, + TIntermTyped *writtenValue) +{ + // Deep copy the left node so that two pointers to the same node don't end up in the tree. + TIntermNode *leftCopy = node->getLeft()->deepCopy(); + ASSERT(leftCopy != nullptr && leftCopy->getAsTyped() != nullptr); + TIntermAggregate *indexedWriteCall = + CreateIndexFunctionCall(node, leftCopy->getAsTyped(), index); + indexedWriteCall->getFunctionSymbolInfo()->setNameObj( + GetIndexFunctionName(node->getLeft()->getType(), true)); + indexedWriteCall->setType(TType(EbtVoid)); + indexedWriteCall->getSequence()->push_back(writtenValue); + return indexedWriteCall; +} + +bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *node) +{ + if (mUsedTreeInsertion) + return false; + + if (node->getOp() == EOpIndexIndirect) + { + if (mRemoveIndexSideEffectsInSubtree) + { + ASSERT(node->getRight()->hasSideEffects()); + // In case we're just removing index side effects, convert + // v_expr[index_expr] + // to this: + // int s0 = index_expr; v_expr[s0]; + // Now v_expr[s0] can be safely executed several times without unintended side effects. + + // Init the temp variable holding the index + TIntermDeclaration *initIndex = createTempInitDeclaration(node->getRight()); + insertStatementInParentBlock(initIndex); + mUsedTreeInsertion = true; + + // Replace the index with the temp variable + TIntermSymbol *tempIndex = createTempSymbol(node->getRight()->getType()); + queueReplacementWithParent(node, node->getRight(), tempIndex, OriginalNode::IS_DROPPED); + } + else if (IntermNodePatternMatcher::IsDynamicIndexingOfVectorOrMatrix(node)) + { + bool write = isLValueRequiredHere(); + +#if defined(ANGLE_ENABLE_ASSERTS) + // Make sure that IntermNodePatternMatcher is consistent with the slightly differently + // implemented checks in this traverser. + IntermNodePatternMatcher matcher( + IntermNodePatternMatcher::kDynamicIndexingOfVectorOrMatrixInLValue); + ASSERT(matcher.match(node, getParentNode(), isLValueRequiredHere()) == write); +#endif + + TType type = node->getLeft()->getType(); + mIndexedVecAndMatrixTypes.insert(type); + + if (write) + { + // Convert: + // v_expr[index_expr]++; + // to this: + // int s0 = index_expr; float s1 = dyn_index(v_expr, s0); s1++; + // dyn_index_write(v_expr, s0, s1); + // This works even if index_expr has some side effects. + if (node->getLeft()->hasSideEffects()) + { + // If v_expr has side effects, those need to be removed before proceeding. + // Otherwise the side effects of v_expr would be evaluated twice. + // The only case where an l-value can have side effects is when it is + // indexing. For example, it can be V[j++] where V is an array of vectors. + mRemoveIndexSideEffectsInSubtree = true; + return true; + } + // TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value + // only writes it and doesn't need the previous value. http://anglebug.com/1116 + + mWrittenVecAndMatrixTypes.insert(type); + TType fieldType = GetFieldType(type); + + TIntermSequence insertionsBefore; + TIntermSequence insertionsAfter; + + // Store the index in a temporary signed int variable. + TIntermTyped *indexInitializer = EnsureSignedInt(node->getRight()); + TIntermDeclaration *initIndex = createTempInitDeclaration(indexInitializer); + initIndex->setLine(node->getLine()); + insertionsBefore.push_back(initIndex); + + TIntermAggregate *indexingCall = CreateIndexFunctionCall( + node, node->getLeft(), createTempSymbol(indexInitializer->getType())); + + // Create a node for referring to the index after the nextTemporaryIndex() call + // below. + TIntermSymbol *tempIndex = createTempSymbol(indexInitializer->getType()); + + nextTemporaryIndex(); // From now on, creating temporary symbols that refer to the + // field value. + insertionsBefore.push_back(createTempInitDeclaration(indexingCall)); + + TIntermAggregate *indexedWriteCall = + CreateIndexedWriteFunctionCall(node, tempIndex, createTempSymbol(fieldType)); + insertionsAfter.push_back(indexedWriteCall); + insertStatementsInParentBlock(insertionsBefore, insertionsAfter); + queueReplacement(node, createTempSymbol(fieldType), OriginalNode::IS_DROPPED); + mUsedTreeInsertion = true; + } + else + { + // The indexed value is not being written, so we can simply convert + // v_expr[index_expr] + // into + // dyn_index(v_expr, index_expr) + // If the index_expr is unsigned, we'll convert it to signed. + ASSERT(!mRemoveIndexSideEffectsInSubtree); + TIntermAggregate *indexingCall = CreateIndexFunctionCall( + node, node->getLeft(), EnsureSignedInt(node->getRight())); + queueReplacement(node, indexingCall, OriginalNode::IS_DROPPED); + } + } + } + return !mUsedTreeInsertion; +} + +void RemoveDynamicIndexingTraverser::nextIteration() +{ + mUsedTreeInsertion = false; + mRemoveIndexSideEffectsInSubtree = false; + nextTemporaryIndex(); +} + +} // namespace + +void RemoveDynamicIndexing(TIntermNode *root, + unsigned int *temporaryIndex, + const TSymbolTable &symbolTable, + int shaderVersion) +{ + RemoveDynamicIndexingTraverser traverser(symbolTable, shaderVersion); + ASSERT(temporaryIndex != nullptr); + traverser.useTemporaryIndex(temporaryIndex); + do + { + traverser.nextIteration(); + root->traverse(&traverser); + traverser.updateTree(); + } while (traverser.usedTreeInsertion()); + traverser.insertHelperDefinitions(root); + traverser.updateTree(); +} + +} // namespace sh |