// // Copyright (c) 2002-2013 The ANGLE Project Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // // UnfoldShortCircuitToIf is an AST traverser to convert short-circuiting operators to if-else statements. // The results are assigned to s# temporaries, which are used by the main translator instead of // the original expression. // #include "compiler/translator/UnfoldShortCircuitToIf.h" #include "compiler/translator/IntermNode.h" #include "compiler/translator/IntermNodePatternMatcher.h" namespace { // Traverser that unfolds one short-circuiting operation at a time. class UnfoldShortCircuitTraverser : public TIntermTraverser { public: UnfoldShortCircuitTraverser(); bool visitBinary(Visit visit, TIntermBinary *node) override; bool visitSelection(Visit visit, TIntermSelection *node) override; void nextIteration(); bool foundShortCircuit() const { return mFoundShortCircuit; } protected: // Marked to true once an operation that needs to be unfolded has been found. // After that, no more unfolding is performed on that traversal. bool mFoundShortCircuit; IntermNodePatternMatcher mPatternToUnfoldMatcher; }; UnfoldShortCircuitTraverser::UnfoldShortCircuitTraverser() : TIntermTraverser(true, false, true), mFoundShortCircuit(false), mPatternToUnfoldMatcher(IntermNodePatternMatcher::kUnfoldedShortCircuitExpression) { } bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node) { if (mFoundShortCircuit) return false; if (visit != PreVisit) return true; if (!mPatternToUnfoldMatcher.match(node, getParentNode())) return true; // If our right node doesn't have side effects, we know we don't need to unfold this // expression: there will be no short-circuiting side effects to avoid // (note: unfolding doesn't depend on the left node -- it will always be evaluated) ASSERT(node->getRight()->hasSideEffects()); mFoundShortCircuit = true; switch (node->getOp()) { case EOpLogicalOr: { // "x || y" is equivalent to "x ? true : y", which unfolds to "bool s; if(x) s = true; // else s = y;", // and then further simplifies down to "bool s = x; if(!s) s = y;". TIntermSequence insertions; TType boolType(EbtBool, EbpUndefined, EvqTemporary); ASSERT(node->getLeft()->getType() == boolType); insertions.push_back(createTempInitDeclaration(node->getLeft())); TIntermAggregate *assignRightBlock = new TIntermAggregate(EOpSequence); ASSERT(node->getRight()->getType() == boolType); assignRightBlock->getSequence()->push_back(createTempAssignment(node->getRight())); TIntermUnary *notTempSymbol = new TIntermUnary(EOpLogicalNot, boolType); notTempSymbol->setOperand(createTempSymbol(boolType)); TIntermSelection *ifNode = new TIntermSelection(notTempSymbol, assignRightBlock, nullptr); insertions.push_back(ifNode); insertStatementsInParentBlock(insertions); queueReplacement(node, createTempSymbol(boolType), OriginalNode::IS_DROPPED); return false; } case EOpLogicalAnd: { // "x && y" is equivalent to "x ? y : false", which unfolds to "bool s; if(x) s = y; // else s = false;", // and then further simplifies down to "bool s = x; if(s) s = y;". TIntermSequence insertions; TType boolType(EbtBool, EbpUndefined, EvqTemporary); ASSERT(node->getLeft()->getType() == boolType); insertions.push_back(createTempInitDeclaration(node->getLeft())); TIntermAggregate *assignRightBlock = new TIntermAggregate(EOpSequence); ASSERT(node->getRight()->getType() == boolType); assignRightBlock->getSequence()->push_back(createTempAssignment(node->getRight())); TIntermSelection *ifNode = new TIntermSelection(createTempSymbol(boolType), assignRightBlock, nullptr); insertions.push_back(ifNode); insertStatementsInParentBlock(insertions); queueReplacement(node, createTempSymbol(boolType), OriginalNode::IS_DROPPED); return false; } default: UNREACHABLE(); return true; } } bool UnfoldShortCircuitTraverser::visitSelection(Visit visit, TIntermSelection *node) { if (mFoundShortCircuit) return false; if (visit != PreVisit) return true; if (!mPatternToUnfoldMatcher.match(node)) return true; mFoundShortCircuit = true; ASSERT(node->usesTernaryOperator()); // Unfold "b ? x : y" into "type s; if(b) s = x; else s = y;" TIntermSequence insertions; TIntermSymbol *tempSymbol = createTempSymbol(node->getType()); TIntermAggregate *tempDeclaration = new TIntermAggregate(EOpDeclaration); tempDeclaration->getSequence()->push_back(tempSymbol); insertions.push_back(tempDeclaration); TIntermAggregate *trueBlock = new TIntermAggregate(EOpSequence); TIntermBinary *trueAssignment = createTempAssignment(node->getTrueBlock()->getAsTyped()); trueBlock->getSequence()->push_back(trueAssignment); TIntermAggregate *falseBlock = new TIntermAggregate(EOpSequence); TIntermBinary *falseAssignment = createTempAssignment(node->getFalseBlock()->getAsTyped()); falseBlock->getSequence()->push_back(falseAssignment); TIntermSelection *ifNode = new TIntermSelection(node->getCondition()->getAsTyped(), trueBlock, falseBlock); insertions.push_back(ifNode); insertStatementsInParentBlock(insertions); TIntermSymbol *ternaryResult = createTempSymbol(node->getType()); queueReplacement(node, ternaryResult, OriginalNode::IS_DROPPED); return false; } void UnfoldShortCircuitTraverser::nextIteration() { mFoundShortCircuit = false; nextTemporaryIndex(); } } // namespace void UnfoldShortCircuitToIf(TIntermNode *root, unsigned int *temporaryIndex) { UnfoldShortCircuitTraverser traverser; ASSERT(temporaryIndex != nullptr); traverser.useTemporaryIndex(temporaryIndex); // Unfold one operator at a time, and reset the traverser between iterations. do { traverser.nextIteration(); root->traverse(&traverser); if (traverser.foundShortCircuit()) traverser.updateTree(); } while (traverser.foundShortCircuit()); }