diff options
Diffstat (limited to 'ipc/ipdl/ipdl')
-rw-r--r-- | ipc/ipdl/ipdl/__init__.py | 77 | ||||
-rw-r--r-- | ipc/ipdl/ipdl/ast.py | 454 | ||||
-rw-r--r-- | ipc/ipdl/ipdl/builtin.py | 59 | ||||
-rw-r--r-- | ipc/ipdl/ipdl/cgen.py | 101 | ||||
-rw-r--r-- | ipc/ipdl/ipdl/cxx/__init__.py | 6 | ||||
-rw-r--r-- | ipc/ipdl/ipdl/cxx/ast.py | 809 | ||||
-rw-r--r-- | ipc/ipdl/ipdl/cxx/cgen.py | 520 | ||||
-rw-r--r-- | ipc/ipdl/ipdl/lower.py | 4822 | ||||
-rw-r--r-- | ipc/ipdl/ipdl/parser.py | 807 | ||||
-rw-r--r-- | ipc/ipdl/ipdl/type.py | 2200 |
10 files changed, 9855 insertions, 0 deletions
diff --git a/ipc/ipdl/ipdl/__init__.py b/ipc/ipdl/ipdl/__init__.py new file mode 100644 index 000000000..d2d883f86 --- /dev/null +++ b/ipc/ipdl/ipdl/__init__.py @@ -0,0 +1,77 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +__all__ = [ 'gencxx', 'genipdl', 'parse', 'typecheck', 'writeifmodified' ] + +import os, sys +from cStringIO import StringIO + +from ipdl.cgen import IPDLCodeGen +from ipdl.lower import LowerToCxx, msgenums +from ipdl.parser import Parser +from ipdl.type import TypeCheck + +from ipdl.cxx.cgen import CxxCodeGen + + +def parse(specstring, filename='/stdin', includedirs=[ ], errout=sys.stderr): + '''Return an IPDL AST if parsing was successful. Print errors to |errout| + if it is not.''' + # The file type and name are later enforced by the type checker. + # This is just a hint to the parser. + prefix, ext = os.path.splitext(filename) + name = os.path.basename(prefix) + if ext == '.ipdlh': + type = 'header' + else: + type = 'protocol' + return Parser(type, name).parse(specstring, os.path.abspath(filename), includedirs, errout) + + +def typecheck(ast, errout=sys.stderr): + '''Return True iff |ast| is well typed. Print errors to |errout| if + it is not.''' + return TypeCheck().check(ast, errout) + + +def gencxx(ipdlfilename, ast, outheadersdir, outcppdir): + headers, cpps = LowerToCxx().lower(ast) + + def resolveHeader(hdr): + return [ + hdr, + os.path.join( + outheadersdir, + *([ns.name for ns in ast.namespaces] + [hdr.name])) + ] + def resolveCpp(cpp): + return [ cpp, os.path.join(outcppdir, cpp.name) ] + + for ast, filename in ([ resolveHeader(hdr) for hdr in headers ] + + [ resolveCpp(cpp) for cpp in cpps ]): + tempfile = StringIO() + CxxCodeGen(tempfile).cgen(ast) + writeifmodified(tempfile.getvalue(), filename) + + +def genipdl(ast, outdir): + return IPDLCodeGen().cgen(ast) + + +def genmsgenum(ast): + return msgenums(ast.protocol, pretty=True) + +def writeifmodified(contents, file): + dir = os.path.dirname(file) + os.path.exists(dir) or os.makedirs(dir) + + oldcontents = None + if os.path.exists(file): + fd = open(file, 'rb') + oldcontents = fd.read() + fd.close() + if oldcontents != contents: + fd = open(file, 'wb') + fd.write(contents) + fd.close() diff --git a/ipc/ipdl/ipdl/ast.py b/ipc/ipdl/ipdl/ast.py new file mode 100644 index 000000000..a8bd1e41f --- /dev/null +++ b/ipc/ipdl/ipdl/ast.py @@ -0,0 +1,454 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +import sys + +NOT_NESTED = 1 +INSIDE_SYNC_NESTED = 2 +INSIDE_CPOW_NESTED = 3 + +NORMAL_PRIORITY = 1 +HIGH_PRIORITY = 2 + +class Visitor: + def defaultVisit(self, node): + raise Exception, "INTERNAL ERROR: no visitor for node type `%s'"% ( + node.__class__.__name__) + + def visitTranslationUnit(self, tu): + for cxxInc in tu.cxxIncludes: + cxxInc.accept(self) + for inc in tu.includes: + inc.accept(self) + for su in tu.structsAndUnions: + su.accept(self) + for using in tu.builtinUsing: + using.accept(self) + for using in tu.using: + using.accept(self) + if tu.protocol: + tu.protocol.accept(self) + + + def visitCxxInclude(self, inc): + pass + + def visitInclude(self, inc): + # Note: we don't visit the child AST here, because that needs delicate + # and pass-specific handling + pass + + def visitStructDecl(self, struct): + for f in struct.fields: + f.accept(self) + + def visitStructField(self, field): + field.typespec.accept(self) + + def visitUnionDecl(self, union): + for t in union.components: + t.accept(self) + + def visitUsingStmt(self, using): + pass + + def visitProtocol(self, p): + for namespace in p.namespaces: + namespace.accept(self) + for spawns in p.spawnsStmts: + spawns.accept(self) + for bridges in p.bridgesStmts: + bridges.accept(self) + for opens in p.opensStmts: + opens.accept(self) + for mgr in p.managers: + mgr.accept(self) + for managed in p.managesStmts: + managed.accept(self) + for msgDecl in p.messageDecls: + msgDecl.accept(self) + for transitionStmt in p.transitionStmts: + transitionStmt.accept(self) + + def visitNamespace(self, ns): + pass + + def visitSpawnsStmt(self, spawns): + pass + + def visitBridgesStmt(self, bridges): + pass + + def visitOpensStmt(self, opens): + pass + + def visitManager(self, mgr): + pass + + def visitManagesStmt(self, mgs): + pass + + def visitMessageDecl(self, md): + for inParam in md.inParams: + inParam.accept(self) + for outParam in md.outParams: + outParam.accept(self) + + def visitTransitionStmt(self, ts): + ts.state.accept(self) + for trans in ts.transitions: + trans.accept(self) + + def visitTransition(self, t): + for toState in t.toStates: + toState.accept(self) + + def visitState(self, s): + pass + + def visitParam(self, decl): + pass + + def visitTypeSpec(self, ts): + pass + + def visitDecl(self, d): + pass + +class Loc: + def __init__(self, filename='<??>', lineno=0): + assert filename + self.filename = filename + self.lineno = lineno + def __repr__(self): + return '%r:%r'% (self.filename, self.lineno) + def __str__(self): + return '%s:%s'% (self.filename, self.lineno) + +Loc.NONE = Loc(filename='<??>', lineno=0) + +class _struct: + pass + +class Node: + def __init__(self, loc=Loc.NONE): + self.loc = loc + + def accept(self, visitor): + visit = getattr(visitor, 'visit'+ self.__class__.__name__, None) + if visit is None: + return getattr(visitor, 'defaultVisit')(self) + return visit(self) + + def addAttrs(self, attrsName): + if not hasattr(self, attrsName): + setattr(self, attrsName, _struct()) + + +class NamespacedNode(Node): + def __init__(self, loc=Loc.NONE, name=None): + Node.__init__(self, loc) + self.name = name + self.namespaces = [ ] + + def addOuterNamespace(self, namespace): + self.namespaces.insert(0, namespace) + + def qname(self): + return QualifiedId(self.loc, self.name, + [ ns.name for ns in self.namespaces ]) + +class TranslationUnit(NamespacedNode): + def __init__(self, type, name): + NamespacedNode.__init__(self, name=name) + self.filetype = type + self.filename = None + self.cxxIncludes = [ ] + self.includes = [ ] + self.builtinUsing = [ ] + self.using = [ ] + self.structsAndUnions = [ ] + self.protocol = None + + def addCxxInclude(self, cxxInclude): self.cxxIncludes.append(cxxInclude) + def addInclude(self, inc): self.includes.append(inc) + def addStructDecl(self, struct): self.structsAndUnions.append(struct) + def addUnionDecl(self, union): self.structsAndUnions.append(union) + def addUsingStmt(self, using): self.using.append(using) + + def setProtocol(self, protocol): self.protocol = protocol + +class CxxInclude(Node): + def __init__(self, loc, cxxFile): + Node.__init__(self, loc) + self.file = cxxFile + +class Include(Node): + def __init__(self, loc, type, name): + Node.__init__(self, loc) + suffix = 'ipdl' + if type == 'header': + suffix += 'h' + self.file = "%s.%s" % (name, suffix) + +class UsingStmt(Node): + def __init__(self, loc, cxxTypeSpec, cxxHeader=None, kind=None): + Node.__init__(self, loc) + assert not isinstance(cxxTypeSpec, str) + assert cxxHeader is None or isinstance(cxxHeader, str); + assert kind is None or kind == 'class' or kind == 'struct' + self.type = cxxTypeSpec + self.header = cxxHeader + self.kind = kind + def canBeForwardDeclared(self): + return self.isClass() or self.isStruct() + def isClass(self): + return self.kind == 'class' + def isStruct(self): + return self.kind == 'struct' + +# "singletons" +class PrettyPrinted: + @classmethod + def __hash__(cls): return hash(cls.pretty) + @classmethod + def __str__(cls): return cls.pretty + +class ASYNC(PrettyPrinted): + pretty = 'async' +class INTR(PrettyPrinted): + pretty = 'intr' +class SYNC(PrettyPrinted): + pretty = 'sync' + +class INOUT(PrettyPrinted): + pretty = 'inout' +class IN(PrettyPrinted): + pretty = 'in' +class OUT(PrettyPrinted): + pretty = 'out' + + +class Namespace(Node): + def __init__(self, loc, namespace): + Node.__init__(self, loc) + self.name = namespace + +class Protocol(NamespacedNode): + def __init__(self, loc): + NamespacedNode.__init__(self, loc) + self.sendSemantics = ASYNC + self.nested = NOT_NESTED + self.spawnsStmts = [ ] + self.bridgesStmts = [ ] + self.opensStmts = [ ] + self.managers = [ ] + self.managesStmts = [ ] + self.messageDecls = [ ] + self.transitionStmts = [ ] + self.startStates = [ ] + +class StructField(Node): + def __init__(self, loc, type, name): + Node.__init__(self, loc) + self.typespec = type + self.name = name + +class StructDecl(NamespacedNode): + def __init__(self, loc, name, fields): + NamespacedNode.__init__(self, loc, name) + self.fields = fields + +class UnionDecl(NamespacedNode): + def __init__(self, loc, name, components): + NamespacedNode.__init__(self, loc, name) + self.components = components + +class SpawnsStmt(Node): + def __init__(self, loc, side, proto, spawnedAs): + Node.__init__(self, loc) + self.side = side + self.proto = proto + self.spawnedAs = spawnedAs + +class BridgesStmt(Node): + def __init__(self, loc, parentSide, childSide): + Node.__init__(self, loc) + self.parentSide = parentSide + self.childSide = childSide + +class OpensStmt(Node): + def __init__(self, loc, side, proto): + Node.__init__(self, loc) + self.side = side + self.proto = proto + +class Manager(Node): + def __init__(self, loc, managerName): + Node.__init__(self, loc) + self.name = managerName + +class ManagesStmt(Node): + def __init__(self, loc, managedName): + Node.__init__(self, loc) + self.name = managedName + +class MessageDecl(Node): + def __init__(self, loc): + Node.__init__(self, loc) + self.name = None + self.sendSemantics = ASYNC + self.nested = NOT_NESTED + self.prio = NORMAL_PRIORITY + self.direction = None + self.inParams = [ ] + self.outParams = [ ] + self.compress = '' + self.verify = '' + + def addInParams(self, inParamsList): + self.inParams += inParamsList + + def addOutParams(self, outParamsList): + self.outParams += outParamsList + + def addModifiers(self, modifiers): + for modifier in modifiers: + if modifier.startswith('compress'): + self.compress = modifier + elif modifier == 'verify': + self.verify = modifier + elif modifier != '': + raise Exception, "Unexpected message modifier `%s'"% modifier + +class Transition(Node): + def __init__(self, loc, trigger, msg, toStates): + Node.__init__(self, loc) + self.trigger = trigger + self.msg = msg + self.toStates = toStates + + def __cmp__(self, o): + c = cmp(self.msg, o.msg) + if c: return c + c = cmp(self.trigger, o.trigger) + if c: return c + + def __hash__(self): return hash(str(self)) + def __str__(self): return '%s %s'% (self.trigger, self.msg) + + @staticmethod + def nameToTrigger(name): + return { 'send': SEND, 'recv': RECV, 'call': CALL, 'answer': ANSWER }[name] + +Transition.NULL = Transition(Loc.NONE, None, None, [ ]) + +class TransitionStmt(Node): + def __init__(self, loc, state, transitions): + Node.__init__(self, loc) + self.state = state + self.transitions = transitions + + @staticmethod + def makeNullStmt(state): + return TransitionStmt(Loc.NONE, state, [ Transition.NULL ]) + +class SEND: + pretty = 'send' + @classmethod + def __hash__(cls): return hash(cls.pretty) + @classmethod + def direction(cls): return OUT +class RECV: + pretty = 'recv' + @classmethod + def __hash__(cls): return hash(cls.pretty) + @classmethod + def direction(cls): return IN +class CALL: + pretty = 'call' + @classmethod + def __hash__(cls): return hash(cls.pretty) + @classmethod + def direction(cls): return OUT +class ANSWER: + pretty = 'answer' + @classmethod + def __hash__(cls): return hash(cls.pretty) + @classmethod + def direction(cls): return IN + +class State(Node): + def __init__(self, loc, name, start=False): + Node.__init__(self, loc) + self.name = name + self.start = start + def __eq__(self, o): + return (isinstance(o, State) + and o.name == self.name + and o.start == self.start) + def __hash__(self): + return hash(repr(self)) + def __ne__(self, o): + return not (self == o) + def __repr__(self): return '<State %r start=%r>'% (self.name, self.start) + def __str__(self): return '<State %s start=%s>'% (self.name, self.start) + +State.ANY = State(Loc.NONE, '[any]', start=True) +State.DEAD = State(Loc.NONE, '[dead]', start=False) +State.DYING = State(Loc.NONE, '[dying]', start=False) + +class Param(Node): + def __init__(self, loc, typespec, name): + Node.__init__(self, loc) + self.name = name + self.typespec = typespec + +class TypeSpec(Node): + def __init__(self, loc, spec, state=None, array=0, nullable=0, + myChmod=None, otherChmod=None): + Node.__init__(self, loc) + self.spec = spec # QualifiedId + self.state = state # None or State + self.array = array # bool + self.nullable = nullable # bool + self.myChmod = myChmod # None or string + self.otherChmod = otherChmod # None or string + + def basename(self): + return self.spec.baseid + + def isActor(self): + return self.state is not None + + def __str__(self): return str(self.spec) + +class QualifiedId: # FIXME inherit from node? + def __init__(self, loc, baseid, quals=[ ]): + assert isinstance(baseid, str) + for qual in quals: assert isinstance(qual, str) + + self.loc = loc + self.baseid = baseid + self.quals = quals + + def qualify(self, id): + self.quals.append(self.baseid) + self.baseid = id + + def __str__(self): + if 0 == len(self.quals): + return self.baseid + return '::'.join(self.quals) +'::'+ self.baseid + +# added by type checking passes +class Decl(Node): + def __init__(self, loc): + Node.__init__(self, loc) + self.progname = None # what the programmer typed, if relevant + self.shortname = None # shortest way to refer to this decl + self.fullname = None # full way to refer to this decl + self.loc = loc + self.type = None + self.scope = None diff --git a/ipc/ipdl/ipdl/builtin.py b/ipc/ipdl/ipdl/builtin.py new file mode 100644 index 000000000..3a49b5c7e --- /dev/null +++ b/ipc/ipdl/ipdl/builtin.py @@ -0,0 +1,59 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +# WARNING: the syntax of the builtin types is not checked, so please +# don't add something syntactically invalid. It will not be fun to +# track down the bug. + +Types = ( + # C types + 'bool', + 'char', + 'short', + 'int', + 'long', + 'float', + 'double', + + # stdint types + 'int8_t', + 'uint8_t', + 'int16_t', + 'uint16_t', + 'int32_t', + 'uint32_t', + 'int64_t', + 'uint64_t', + 'intptr_t', + 'uintptr_t', + + # stddef types + 'size_t', + 'ssize_t', + + # Mozilla types: "less" standard things we know how serialize/deserialize + 'nsresult', + 'nsString', + 'nsCString', + 'mozilla::ipc::Shmem', + 'mozilla::ipc::FileDescriptor' +) + + +HeaderIncludes = ( + 'mozilla/Attributes.h', + 'IPCMessageStart.h', + 'ipc/IPCMessageUtils.h', + 'mozilla/RefPtr.h', + 'nsStringGlue.h', + 'nsTArray.h', + 'mozilla/ipc/ProtocolUtils.h', + 'nsTHashtable.h', + 'mozilla/OperatorNewExtensions.h', +) + +CppIncludes = ( + 'nsIFile.h', + 'GeckoProfiler.h', +) diff --git a/ipc/ipdl/ipdl/cgen.py b/ipc/ipdl/ipdl/cgen.py new file mode 100644 index 000000000..fd8951c74 --- /dev/null +++ b/ipc/ipdl/ipdl/cgen.py @@ -0,0 +1,101 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +import os, sys + +from ipdl.ast import Visitor +from ipdl.ast import IN, OUT, INOUT, ASYNC, SYNC, INTR + +class CodePrinter: + def __init__(self, outf=sys.stdout, indentCols=4): + self.outf = outf + self.col = 0 + self.indentCols = indentCols + + def write(self, str): + self.outf.write(str) + + def printdent(self, str=''): + self.write((' '* self.col) + str) + + def println(self, str=''): + self.write(str +'\n') + + def printdentln(self, str): + self.write((' '* self.col) + str +'\n') + + def indent(self): self.col += self.indentCols + def dedent(self): self.col -= self.indentCols + + +##----------------------------------------------------------------------------- +class IPDLCodeGen(CodePrinter, Visitor): + '''Spits back out equivalent IPDL to the code that generated this. +Also known as pretty-printing.''' + + def __init__(self, outf=sys.stdout, indentCols=4, printed=set()): + CodePrinter.__init__(self, outf, indentCols) + self.printed = printed + + def visitTranslationUnit(self, tu): + self.printed.add(tu.filename) + self.println('//\n// Automatically generated by ipdlc\n//') + CodeGen.visitTranslationUnit(self, tu) + + def visitCxxInclude(self, inc): + self.println('include "'+ inc.file +'";') + + def visitProtocolInclude(self, inc): + self.println('include protocol "'+ inc.file +'";') + if inc.tu.filename not in self.printed: + self.println('/* Included file:') + IPDLCodeGen(outf=self.outf, indentCols=self.indentCols, + printed=self.printed).visitTranslationUnit(inc.tu) + + self.println('*/') + + def visitProtocol(self, p): + self.println() + for namespace in p.namespaces: namespace.accept(self) + + self.println('%s protocol %s\n{'% (p.sendSemantics[0], p.name)) + self.indent() + + for mgs in p.managesStmts: + mgs.accept(self) + if len(p.managesStmts): self.println() + + for msgDecl in p.messageDecls: msgDecl.accept(self) + self.println() + + for transStmt in p.transitionStmts: transStmt.accept(self) + + self.dedent() + self.println('}') + self.write('}\n'* len(p.namespaces)) + + def visitManagerStmt(self, mgr): + self.printdentln('manager '+ mgr.name +';') + + def visitManagesStmt(self, mgs): + self.printdentln('manages '+ mgs.name +';') + + def visitMessageDecl(self, msg): + self.printdent('%s %s %s('% (msg.sendSemantics[0], msg.direction[0], msg.name)) + for i, inp in enumerate(msg.inParams): + inp.accept(self) + if i != (len(msg.inParams) - 1): self.write(', ') + self.write(')') + if 0 == len(msg.outParams): + self.println(';') + return + + self.println() + self.indent() + self.printdent('returns (') + for i, outp in enumerate(msg.outParams): + outp.accept(self) + if i != (len(msg.outParams) - 1): self.write(', ') + self.println(');') + self.dedent() diff --git a/ipc/ipdl/ipdl/cxx/__init__.py b/ipc/ipdl/ipdl/cxx/__init__.py new file mode 100644 index 000000000..f43d5b0db --- /dev/null +++ b/ipc/ipdl/ipdl/cxx/__init__.py @@ -0,0 +1,6 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +import ipdl.cxx.ast +import ipdl.cxx.cgen diff --git a/ipc/ipdl/ipdl/cxx/ast.py b/ipc/ipdl/ipdl/cxx/ast.py new file mode 100644 index 000000000..18c2b3f1d --- /dev/null +++ b/ipc/ipdl/ipdl/cxx/ast.py @@ -0,0 +1,809 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +import copy, sys + +class Visitor: + def defaultVisit(self, node): + raise Exception, "INTERNAL ERROR: no visitor for node type `%s'"% ( + node.__class__.__name__) + + def visitWhitespace(self, ws): + pass + + def visitFile(self, f): + for thing in f.stuff: + thing.accept(self) + + def visitCppDirective(self, ppd): + pass + + def visitBlock(self, block): + for stmt in block.stmts: + stmt.accept(self) + + def visitNamespace(self, ns): + self.visitBlock(ns) + + def visitType(self, type): + pass + + def visitTypeArray(self, ta): + ta.basetype.accept(self) + ta.nmemb.accept(self) + + def visitTypeEnum(self, enum): + pass + + def visitTypeUnion(self, union): + for t, name in union.components: + t.accept(self) + + def visitTypedef(self, tdef): + tdef.fromtype.accept(self) + + def visitUsing(self, us): + us.type.accept(self) + + def visitForwardDecl(self, fd): + pass + + def visitDecl(self, decl): + decl.type.accept(self) + + def visitParam(self, param): + self.visitDecl(param) + if param.default is not None: + param.default.accept(self) + + def visitClass(self, cls): + for inherit in cls.inherits: + inherit.accept(self) + self.visitBlock(cls) + + def visitInherit(self, inh): + pass + + def visitFriendClassDecl(self, fcd): + pass + + def visitMethodDecl(self, meth): + for param in meth.params: + param.accept(self) + if meth.ret is not None: + meth.ret.accept(self) + if meth.typeop is not None: + meth.typeop.accept(self) + if meth.T is not None: + meth.T.accept(self) + + def visitMethodDefn(self, meth): + meth.decl.accept(self) + self.visitBlock(meth) + + def visitFunctionDecl(self, fun): + self.visitMethodDecl(fun) + + def visitFunctionDefn(self, fd): + self.visitMethodDefn(fd) + + def visitConstructorDecl(self, ctor): + self.visitMethodDecl(ctor) + + def visitConstructorDefn(self, cd): + cd.decl.accept(self) + for init in cd.memberinits: + init.accept(self) + self.visitBlock(cd) + + def visitDestructorDecl(self, dtor): + self.visitMethodDecl(dtor) + + def visitDestructorDefn(self, dd): + dd.decl.accept(self) + self.visitBlock(dd) + + def visitExprLiteral(self, l): + pass + + def visitExprVar(self, v): + pass + + def visitExprPrefixUnop(self, e): + e.expr.accept(self) + + def visitExprBinary(self, e): + e.left.accept(self) + e.right.accept(self) + + def visitExprConditional(self, c): + c.cond.accept(self) + c.ife.accept(self) + c.elsee.accept(self) + + def visitExprAddrOf(self, eao): + self.visitExprPrefixUnop(eao) + + def visitExprDeref(self, ed): + self.visitExprPrefixUnop(ed) + + def visitExprNot(self, en): + self.visitExprPrefixUnop(en) + + def visitExprCast(self, ec): + ec.expr.accept(self) + + def visitExprIndex(self, ei): + ei.arr.accept(self) + ei.idx.accept(self) + + def visitExprSelect(self, es): + es.obj.accept(self) + + def visitExprAssn(self, ea): + ea.lhs.accept(self) + ea.rhs.accept(self) + + def visitExprCall(self, ec): + ec.func.accept(self) + for arg in ec.args: + arg.accept(self) + + def visitExprNew(self, en): + en.ctype.accept(self) + if en.newargs is not None: + for arg in en.newargs: + arg.accept(self) + if en.args is not None: + for arg in en.args: + arg.accept(self) + + def visitExprDelete(self, ed): + ed.obj.accept(self) + + def visitExprMemberInit(self, minit): + self.visitExprCall(minit) + + def visitExprSizeof(self, es): + self.visitExprCall(es) + + def visitStmtBlock(self, sb): + self.visitBlock(sb) + + def visitStmtDecl(self, sd): + sd.decl.accept(self) + if sd.init is not None: + sd.init.accept(self) + + def visitLabel(self, label): + pass + + def visitCaseLabel(self, case): + pass + + def visitDefaultLabel(self, dl): + pass + + def visitStmtIf(self, si): + si.cond.accept(self) + si.ifb.accept(self) + if si.elseb is not None: + si.elseb.accept(self) + + def visitStmtFor(self, sf): + if sf.init is not None: + sf.init.accept(self) + if sf.cond is not None: + sf.cond.accept(self) + if sf.update is not None: + sf.update.accept(self) + + def visitStmtSwitch(self, ss): + ss.expr.accept(self) + self.visitBlock(ss) + + def visitStmtBreak(self, sb): + pass + + def visitStmtExpr(self, se): + se.expr.accept(self) + + def visitStmtReturn(self, sr): + if sr.expr is not None: + sr.expr.accept(self) + +##------------------------------ +class Node: + def __init__(self): + pass + + def accept(self, visitor): + visit = getattr(visitor, 'visit'+ self.__class__.__name__, None) + if visit is None: + return getattr(visitor, 'defaultVisit')(self) + return visit(self) + +class Whitespace(Node): + # yes, this is silly. but we need to stick comments in the + # generated code without resorting to more serious hacks + def __init__(self, ws, indent=0): + Node.__init__(self) + self.ws = ws + self.indent = indent +Whitespace.NL = Whitespace('\n') + +class File(Node): + def __init__(self, filename): + Node.__init__(self) + self.name = filename + # array of stuff in the file --- stmts and preprocessor thingies + self.stuff = [ ] + + def addthing(self, thing): + assert thing is not None + assert not isinstance(thing, list) + self.stuff.append(thing) + + def addthings(self, things): + for t in things: self.addthing(t) + + # "look like" a Block so code doesn't have to care whether they're + # in global scope or not + def addstmt(self, stmt): + assert stmt is not None + assert not isinstance(stmt, list) + self.stuff.append(stmt) + + def addstmts(self, stmts): + for s in stmts: self.addstmt(s) + +class CppDirective(Node): + '''represents |#[directive] [rest]|, where |rest| is any string''' + def __init__(self, directive, rest=None): + Node.__init__(self) + self.directive = directive + self.rest = rest + +class Block(Node): + def __init__(self): + Node.__init__(self) + self.stmts = [ ] + + def addstmt(self, stmt): + assert stmt is not None + assert not isinstance(stmt, tuple) + self.stmts.append(stmt) + + def addstmts(self, stmts): + for s in stmts: self.addstmt(s) + +##------------------------------ +# type and decl thingies +class Namespace(Block): + def __init__(self, name): + assert isinstance(name, str) + + Block.__init__(self) + self.name = name + +class Type(Node): + def __init__(self, name, const=0, + ptr=0, ptrconst=0, ptrptr=0, ptrconstptr=0, + ref=0, + hasimplicitcopyctor=True, + T=None): + """ +To avoid getting fancy with recursive types, we limit the kinds +of pointer types that can be be constructed. + + ptr => T* + ptrconst => T* const + ptrptr => T** + ptrconstptr => T* const* + +Any type, naked or pointer, can be const (const T) or ref (T&). +""" + assert isinstance(name, str) + assert not isinstance(const, str) + assert not isinstance(T, str) + + Node.__init__(self) + self.name = name + self.const = const + self.ptr = ptr + self.ptrconst = ptrconst + self.ptrptr = ptrptr + self.ptrconstptr = ptrconstptr + self.ref = ref + self.hasimplicitcopyctor = hasimplicitcopyctor + self.T = T + # XXX could get serious here with recursive types, but shouldn't + # need that for this codegen + def __deepcopy__(self, memo): + return Type(self.name, + const=self.const, + ptr=self.ptr, ptrconst=self.ptrconst, + ptrptr=self.ptrptr, ptrconstptr=self.ptrconstptr, + ref=self.ref, + T=copy.deepcopy(self.T, memo)) +Type.BOOL = Type('bool') +Type.INT = Type('int') +Type.INT32 = Type('int32_t') +Type.INTPTR = Type('intptr_t') +Type.NSRESULT = Type('nsresult') +Type.UINT32 = Type('uint32_t') +Type.UINT32PTR = Type('uint32_t', ptr=1) +Type.SIZE = Type('size_t') +Type.VOID = Type('void') +Type.VOIDPTR = Type('void', ptr=1) +Type.AUTO = Type('auto') + +class TypeArray(Node): + def __init__(self, basetype, nmemb): + '''the type |basetype DECLNAME[nmemb]|. |nmemb| is an Expr''' + self.basetype = basetype + self.nmemb = nmemb + def __deepcopy__(self, memo): + return TypeArray(deepcopy(self.basetype, memo), nmemb) + +class TypeEnum(Node): + def __init__(self, name=None): + '''name can be None''' + Node.__init__(self) + self.name = name + self.idnums = [ ] # pairs of ('Foo', [num]) or ('Foo', None) + + def addId(self, id, num=None): + self.idnums.append((id, num)) + +class TypeUnion(Node): + def __init__(self, name=None): + Node.__init__(self) + self.name = name + self.components = [ ] # [ Decl ] + + def addComponent(self, type, name): + self.components.append(Decl(type, name)) + +class Typedef(Node): + def __init__(self, fromtype, totypename, templateargs=[]): + assert isinstance(totypename, str) + + Node.__init__(self) + self.fromtype = fromtype + self.totypename = totypename + self.templateargs = templateargs + + def __cmp__(self, o): + return cmp(self.totypename, o.totypename) + def __eq__(self, o): + return (self.__class__ == o.__class__ + and 0 == cmp(self, o)) + def __hash__(self): + return hash(self.totypename) + +class Using(Node): + def __init__(self, type): + Node.__init__(self) + self.type = type + +class ForwardDecl(Node): + def __init__(self, pqname, cls=0, struct=0): + assert (not cls and struct) or (cls and not struct) + + self.pqname = pqname + self.cls = cls + self.struct = struct + +class Decl(Node): + '''represents |Foo bar|, e.g. in a function signature''' + def __init__(self, type, name): + assert type is not None + assert not isinstance(type, str) + assert isinstance(name, str) + + Node.__init__(self) + self.type = type + self.name = name + def __deepcopy__(self, memo): + return Decl(copy.deepcopy(self.type, memo), self.name) + +class Param(Decl): + def __init__(self, type, name, default=None): + Decl.__init__(self, type, name) + self.default = default + def __deepcopy__(self, memo): + return Param(copy.deepcopy(self.type, memo), self.name, + copy.deepcopy(self.default, memo)) + +##------------------------------ +# class stuff +class Class(Block): + def __init__(self, name, inherits=[ ], + interface=0, abstract=0, final=0, + specializes=None, struct=0): + assert not (interface and abstract) + assert not (abstract and final) + assert not (interface and final) + assert not (inherits and specializes) + + Block.__init__(self) + self.name = name + self.inherits = inherits # [ Type ] + self.interface = interface # bool + self.abstract = abstract # bool + self.final = final # bool + self.specializes = specializes # Type or None + self.struct = struct # bool + +class Inherit(Node): + def __init__(self, type, viz='public'): + assert isinstance(viz, str) + Node.__init__(self) + self.type = type + self.viz = viz + +class FriendClassDecl(Node): + def __init__(self, friend): + Node.__init__(self) + self.friend = friend + +class MethodDecl(Node): + def __init__(self, name, params=[ ], ret=Type('void'), + virtual=0, const=0, pure=0, static=0, warn_unused=0, + inline=0, force_inline=0, never_inline=0, + typeop=None, + T=None): + assert not (virtual and static) + assert not pure or virtual # pure => virtual + assert not (static and typeop) + assert not (name and typeop) + assert name is None or isinstance(name, str) + assert not isinstance(ret, list) + for decl in params: assert not isinstance(decl, str) + assert not isinstance(T, int) + assert not (inline and never_inline) + assert not (force_inline and never_inline) + + if typeop is not None: + ret = None + + Node.__init__(self) + self.name = name + self.params = params # [ Param ] + self.ret = ret # Type or None + self.virtual = virtual # bool + self.const = const # bool + self.pure = pure # bool + self.static = static # bool + self.warn_unused = warn_unused # bool + self.force_inline = (force_inline or T) # bool + self.inline = inline # bool + self.never_inline = never_inline # bool + self.typeop = typeop # Type or None + self.T = T # Type or None + self.only_for_definition = False + + def __deepcopy__(self, memo): + return MethodDecl( + self.name, + params=copy.deepcopy(self.params, memo), + ret=copy.deepcopy(self.ret, memo), + virtual=self.virtual, + const=self.const, + pure=self.pure, + static=self.static, + warn_unused=self.warn_unused, + inline=self.inline, + force_inline=self.force_inline, + never_inline=self.never_inline, + typeop=copy.deepcopy(self.typeop, memo), + T=copy.deepcopy(self.T, memo)) + +class MethodDefn(Block): + def __init__(self, decl): + Block.__init__(self) + self.decl = decl + +class FunctionDecl(MethodDecl): + def __init__(self, name, params=[ ], ret=Type('void'), + static=0, warn_unused=0, + inline=0, force_inline=0, + T=None): + MethodDecl.__init__(self, name, params=params, ret=ret, + static=static, warn_unused=warn_unused, + inline=inline, force_inline=force_inline, + T=T) + +class FunctionDefn(MethodDefn): + def __init__(self, decl): + MethodDefn.__init__(self, decl) + +class ConstructorDecl(MethodDecl): + def __init__(self, name, params=[ ], explicit=0, force_inline=0): + MethodDecl.__init__(self, name, params=params, ret=None, + force_inline=force_inline) + self.explicit = explicit + + def __deepcopy__(self, memo): + return ConstructorDecl(self.name, + copy.deepcopy(self.params, memo), + self.explicit) + +class ConstructorDefn(MethodDefn): + def __init__(self, decl, memberinits=[ ]): + MethodDefn.__init__(self, decl) + self.memberinits = memberinits + +class DestructorDecl(MethodDecl): + def __init__(self, name, virtual=0, force_inline=0, inline=0): + MethodDecl.__init__(self, name, params=[ ], ret=None, + virtual=virtual, + force_inline=force_inline, inline=inline) + + def __deepcopy__(self, memo): + return DestructorDecl(self.name, + virtual=self.virtual, + force_inline=self.force_inline, + inline=self.inline) + + +class DestructorDefn(MethodDefn): + def __init__(self, decl): MethodDefn.__init__(self, decl) + +##------------------------------ +# expressions +class ExprVar(Node): + def __init__(self, name): + assert isinstance(name, str) + + Node.__init__(self) + self.name = name +ExprVar.THIS = ExprVar('this') + +class ExprLiteral(Node): + def __init__(self, value, type): + '''|type| is a Python format specifier; 'd' for example''' + Node.__init__(self) + self.value = value + self.type = type + + @staticmethod + def Int(i): return ExprLiteral(i, 'd') + + @staticmethod + def String(s): return ExprLiteral('"'+ s +'"', 's') + + @staticmethod + def WString(s): return ExprLiteral('L"'+ s +'"', 's') + + def __str__(self): + return ('%'+ self.type)% (self.value) +ExprLiteral.ZERO = ExprLiteral.Int(0) +ExprLiteral.ONE = ExprLiteral.Int(1) +ExprLiteral.NULL = ExprVar('nullptr') +ExprLiteral.TRUE = ExprVar('true') +ExprLiteral.FALSE = ExprVar('false') + +class ExprPrefixUnop(Node): + def __init__(self, expr, op): + assert not isinstance(expr, tuple) + self.expr = expr + self.op = op + +class ExprNot(ExprPrefixUnop): + def __init__(self, expr): + ExprPrefixUnop.__init__(self, expr, '!') + +class ExprAddrOf(ExprPrefixUnop): + def __init__(self, expr): + ExprPrefixUnop.__init__(self, expr, '&') + +class ExprDeref(ExprPrefixUnop): + def __init__(self, expr): + ExprPrefixUnop.__init__(self, expr, '*') + +class ExprCast(Node): + def __init__(self, expr, type, + dynamic=0, static=0, reinterpret=0, const=0, C=0): + assert 1 == reduce(lambda a, x: a+x, [ dynamic, static, reinterpret, const, C ]) + + Node.__init__(self) + self.expr = expr + self.type = type + self.dynamic = dynamic + self.static = static + self.reinterpret = reinterpret + self.const = const + self.C = C + +class ExprBinary(Node): + def __init__(self, left, op, right): + Node.__init__(self) + self.left = left + self.op = op + self.right = right + +class ExprConditional(Node): + def __init__(self, cond, ife, elsee): + Node.__init__(self) + self.cond = cond + self.ife = ife + self.elsee = elsee + +class ExprIndex(Node): + def __init__(self, arr, idx): + Node.__init__(self) + self.arr = arr + self.idx = idx + +class ExprSelect(Node): + def __init__(self, obj, op, field): + assert obj and op and field + assert not isinstance(obj, str) + assert isinstance(field, str) + + Node.__init__(self) + self.obj = obj + self.op = op + self.field = field + +class ExprAssn(Node): + def __init__(self, lhs, rhs, op='='): + Node.__init__(self) + self.lhs = lhs + self.op = op + self.rhs = rhs + +class ExprCall(Node): + def __init__(self, func, args=[ ]): + assert hasattr(func, 'accept') + assert isinstance(args, list) + for arg in args: assert arg and not isinstance(arg, str) + + Node.__init__(self) + self.func = func + self.args = args + +class ExprMove(ExprCall): + def __init__(self, arg): + ExprCall.__init__(self, ExprVar("mozilla::Move"), args=[arg]) + +class ExprNew(Node): + # XXX taking some poetic license ... + def __init__(self, ctype, args=[ ], newargs=None): + assert not (ctype.const or ctype.ref) + + Node.__init__(self) + self.ctype = ctype + self.args = args + self.newargs = newargs + +class ExprDelete(Node): + def __init__(self, obj): + Node.__init__(self) + self.obj = obj + +class ExprMemberInit(ExprCall): + def __init__(self, member, args=[ ]): + ExprCall.__init__(self, member, args) + +class ExprSizeof(ExprCall): + def __init__(self, t): + ExprCall.__init__(self, ExprVar('sizeof'), [ t ]) + +##------------------------------ +# statements etc. +class StmtBlock(Block): + def __init__(self, stmts=[ ]): + Block.__init__(self) + self.addstmts(stmts) + +class StmtDecl(Node): + def __init__(self, decl, init=None, initargs=None): + assert not (init and initargs) + assert not isinstance(init, str) # easy to confuse with Decl + assert not isinstance(init, list) + assert not isinstance(decl, tuple) + + Node.__init__(self) + self.decl = decl + self.init = init + self.initargs = initargs + +class Label(Node): + def __init__(self, name): + Node.__init__(self) + self.name = name +Label.PUBLIC = Label('public') +Label.PROTECTED = Label('protected') +Label.PRIVATE = Label('private') + +class CaseLabel(Node): + def __init__(self, name): + Node.__init__(self) + self.name = name + +class DefaultLabel(Node): + def __init__(self): + Node.__init__(self) + +class StmtIf(Node): + def __init__(self, cond): + Node.__init__(self) + self.cond = cond + self.ifb = Block() + self.elseb = None + + def addifstmt(self, stmt): + self.ifb.addstmt(stmt) + + def addifstmts(self, stmts): + self.ifb.addstmts(stmts) + + def addelsestmt(self, stmt): + if self.elseb is None: self.elseb = Block() + self.elseb.addstmt(stmt) + + def addelsestmts(self, stmts): + if self.elseb is None: self.elseb = Block() + self.elseb.addstmts(stmts) + +class StmtFor(Block): + def __init__(self, init=None, cond=None, update=None): + Block.__init__(self) + self.init = init + self.cond = cond + self.update = update + +class StmtRangedFor(Block): + def __init__(self, var, iteree): + assert isinstance(var, ExprVar) + assert iteree + + Block.__init__(self) + self.var = var + self.iteree = iteree + +class StmtSwitch(Block): + def __init__(self, expr): + Block.__init__(self) + self.expr = expr + self.nr_cases = 0 + + def addcase(self, case, block): + '''NOTE: |case| is not checked for uniqueness''' + assert not isinstance(case, str) + assert (isinstance(block, StmtBreak) + or isinstance(block, StmtReturn) + or isinstance(block, StmtSwitch) + or (hasattr(block, 'stmts') + and (isinstance(block.stmts[-1], StmtBreak) + or isinstance(block.stmts[-1], StmtReturn)))) + self.addstmt(case) + self.addstmt(block) + self.nr_cases += 1 + + def addfallthrough(self, case): + self.addstmt(case) + self.nr_cases += 1 + +class StmtBreak(Node): + def __init__(self): + Node.__init__(self) + +class StmtExpr(Node): + def __init__(self, expr): + assert expr is not None + + Node.__init__(self) + self.expr = expr + +class StmtReturn(Node): + def __init__(self, expr=None): + Node.__init__(self) + self.expr = expr + +StmtReturn.TRUE = StmtReturn(ExprLiteral.TRUE) +StmtReturn.FALSE = StmtReturn(ExprLiteral.FALSE) diff --git a/ipc/ipdl/ipdl/cxx/cgen.py b/ipc/ipdl/ipdl/cxx/cgen.py new file mode 100644 index 000000000..30f2f2bca --- /dev/null +++ b/ipc/ipdl/ipdl/cxx/cgen.py @@ -0,0 +1,520 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +import sys + +from ipdl.cgen import CodePrinter +from ipdl.cxx.ast import TypeArray, Visitor + +class CxxCodeGen(CodePrinter, Visitor): + def __init__(self, outf=sys.stdout, indentCols=4): + CodePrinter.__init__(self, outf, indentCols) + + def cgen(self, cxxfile): + cxxfile.accept(self) + + def visitWhitespace(self, ws): + if ws.indent: + self.printdent('') + self.write(ws.ws) + + def visitCppDirective(self, cd): + if cd.rest: + self.println('#%s %s'% (cd.directive, cd.rest)) + else: + self.println('#%s'% (cd.directive)) + + def visitNamespace(self, ns): + self.println('namespace '+ ns.name +' {') + self.visitBlock(ns) + self.println('} // namespace '+ ns.name) + + def visitType(self, t): + if t.const: + self.write('const ') + + self.write(t.name) + + if t.T is not None: + self.write('<') + t.T.accept(self) + self.write('>') + + ts = '' + if t.ptr: ts += '*' + elif t.ptrconst: ts += '* const' + elif t.ptrptr: ts += '**' + elif t.ptrconstptr: ts += '* const*' + + ts += '&' * t.ref + + self.write(ts) + + def visitTypeEnum(self, te): + self.write('enum') + if te.name: + self.write(' '+ te.name) + self.println(' {') + + self.indent() + nids = len(te.idnums) + for i, (id, num) in enumerate(te.idnums): + self.printdent(id) + if num: + self.write(' = '+ str(num)) + if i != (nids-1): + self.write(',') + self.println() + self.dedent() + self.printdent('}') + + def visitTypeUnion(self, u): + self.write('union') + if u.name: + self.write(' '+ u.name) + self.println(' {') + + self.indent() + for decl in u.components: + self.printdent() + decl.accept(self) + self.println(';') + self.dedent() + + self.printdent('}') + + + def visitTypedef(self, td): + if td.templateargs: + formals = ', '.join([ 'class ' + T for T in td.templateargs ]) + args = ', '.join(td.templateargs) + self.printdent('template<' + formals + '> using ' + td.totypename + ' = ') + td.fromtype.accept(self) + self.println('<' + args + '>;') + else: + self.printdent('typedef ') + td.fromtype.accept(self) + self.println(' '+ td.totypename +';') + + def visitUsing(self, us): + self.printdent('using ') + us.type.accept(self) + self.println(';') + + def visitForwardDecl(self, fd): + if fd.cls: self.printdent('class ') + elif fd.struct: self.printdent('struct ') + self.write(str(fd.pqname)) + self.println(';') + + def visitDecl(self, d): + # C-syntax arrays make code generation much more annoying + if isinstance(d.type, TypeArray): + d.type.basetype.accept(self) + else: + d.type.accept(self) + + if d.name: + self.write(' '+ d.name) + + if isinstance(d.type, TypeArray): + self.write('[') + d.type.nmemb.accept(self) + self.write(']') + + def visitParam(self, p): + self.visitDecl(p) + if p.default is not None: + self.write(' = ') + p.default.accept(self) + + def visitClass(self, c): + if c.specializes is not None: + self.printdentln('template<>') + + if c.struct: + self.printdent('struct') + else: + self.printdent('class') + self.write(' '+ c.name) + if c.final: + self.write(' final') + + if c.specializes is not None: + self.write(' <') + c.specializes.accept(self) + self.write('>') + + ninh = len(c.inherits) + if 0 < ninh: + self.println(' :') + self.indent() + for i, inherit in enumerate(c.inherits): + self.printdent() + inherit.accept(self) + if i != (ninh - 1): + self.println(',') + self.dedent() + self.println() + + self.printdentln('{') + self.indent() + + self.visitBlock(c) + + self.dedent() + self.printdentln('};') + + def visitInherit(self, inh): + self.write(inh.viz +' ') + inh.type.accept(self) + + def visitFriendClassDecl(self, fcd): + self.printdentln('friend class '+ fcd.friend +';') + + + def visitMethodDecl(self, md): + assert not (md.static and md.virtual) + + if md.T: + self.write('template<') + self.write('typename ') + md.T.accept(self) + self.println('>') + self.printdent() + + if md.warn_unused: + self.write('MOZ_MUST_USE ') + if md.inline: + self.write('inline ') + if md.never_inline: + self.write('MOZ_NEVER_INLINE ') + if md.static: + self.write('static ') + if md.virtual: + self.write('virtual ') + if md.ret: + if md.only_for_definition: + self.write('auto ') + else: + md.ret.accept(self) + self.println() + self.printdent() + if md.typeop is not None: + self.write('operator ') + md.typeop.accept(self) + else: + self.write(md.name) + + self.write('(') + self.writeDeclList(md.params) + self.write(')') + + if md.const: + self.write(' const') + if md.ret and md.only_for_definition: + self.write(' -> ') + md.ret.accept(self) + if md.pure: + self.write(' = 0') + + + def visitMethodDefn(self, md): + if md.decl.pure: + return + + self.printdent() + md.decl.accept(self) + self.println() + + self.printdentln('{') + self.indent() + self.visitBlock(md) + self.dedent() + self.printdentln('}') + + + def visitConstructorDecl(self, cd): + if cd.explicit: + self.write('explicit ') + else: + self.write('MOZ_IMPLICIT ') + self.visitMethodDecl(cd) + + def visitConstructorDefn(self, cd): + self.printdent() + cd.decl.accept(self) + if len(cd.memberinits): + self.println(' :') + self.indent() + ninits = len(cd.memberinits) + for i, init in enumerate(cd.memberinits): + self.printdent() + init.accept(self) + if i != (ninits-1): + self.println(',') + self.dedent() + self.println() + + self.printdentln('{') + self.indent() + + self.visitBlock(cd) + + self.dedent() + self.printdentln('}') + + + def visitDestructorDecl(self, dd): + if dd.inline: + self.write('inline ') + if dd.virtual: + self.write('virtual ') + + # hack alert + parts = dd.name.split('::') + parts[-1] = '~'+ parts[-1] + + self.write('::'.join(parts) +'()') + + def visitDestructorDefn(self, dd): + self.printdent() + dd.decl.accept(self) + self.println() + + self.printdentln('{') + self.indent() + + self.visitBlock(dd) + + self.dedent() + self.printdentln('}') + + + def visitExprLiteral(self, el): + self.write(str(el)) + + def visitExprVar(self, ev): + self.write(ev.name) + + def visitExprPrefixUnop(self, e): + self.write('(') + self.write(e.op) + self.write('(') + e.expr.accept(self) + self.write(')') + self.write(')') + + def visitExprCast(self, c): + pfx, sfx = '', '' + if c.dynamic: pfx, sfx = 'dynamic_cast<', '>' + elif c.static: pfx, sfx = 'static_cast<', '>' + elif c.reinterpret: pfx, sfx = 'reinterpret_cast<', '>' + elif c.const: pfx, sfx = 'const_cast<', '>' + elif c.C: pfx, sfx = '(', ')' + self.write(pfx) + c.type.accept(self) + self.write(sfx +'(') + c.expr.accept(self) + self.write(')') + + def visitExprBinary(self, e): + self.write('(') + e.left.accept(self) + self.write(') '+ e.op +' (') + e.right.accept(self) + self.write(')') + + def visitExprConditional(self, c): + self.write('(') + c.cond.accept(self) + self.write(' ? ') + c.ife.accept(self) + self.write(' : ') + c.elsee.accept(self) + self.write(')') + + def visitExprIndex(self, ei): + ei.arr.accept(self) + self.write('[') + ei.idx.accept(self) + self.write(']') + + def visitExprSelect(self, es): + self.write('(') + es.obj.accept(self) + self.write(')') + self.write(es.op + es.field) + + def visitExprAssn(self, ea): + ea.lhs.accept(self) + self.write(' '+ ea.op +' ') + ea.rhs.accept(self) + + def visitExprCall(self, ec): + ec.func.accept(self) + self.write('(') + self.writeExprList(ec.args) + self.write(')') + + def visitExprMove(self, em): + self.visitExprCall(em) + + def visitExprNew(self, en): + self.write('new ') + if en.newargs is not None: + self.write('(') + self.writeExprList(en.newargs) + self.write(') ') + en.ctype.accept(self) + if en.args is not None: + self.write('(') + self.writeExprList(en.args) + self.write(')') + + def visitExprDelete(self, ed): + self.write('delete ') + ed.obj.accept(self) + + + def visitStmtBlock(self, b): + self.printdentln('{') + self.indent() + self.visitBlock(b) + self.dedent() + self.printdentln('}') + + def visitLabel(self, label): + self.dedent() # better not be at global scope ... + self.printdentln(label.name +':') + self.indent() + + def visitCaseLabel(self, cl): + self.dedent() + self.printdentln('case '+ cl.name +':') + self.indent() + + def visitDefaultLabel(self, dl): + self.dedent() + self.printdentln('default:') + self.indent() + + + def visitStmtIf(self, si): + self.printdent('if (') + si.cond.accept(self) + self.println(') {') + self.indent() + si.ifb.accept(self) + self.dedent() + self.printdentln('}') + + if si.elseb is not None: + self.printdentln('else {') + self.indent() + si.elseb.accept(self) + self.dedent() + self.printdentln('}') + + + def visitStmtFor(self, sf): + self.printdent('for (') + if sf.init is not None: + sf.init.accept(self) + self.write('; ') + if sf.cond is not None: + sf.cond.accept(self) + self.write('; ') + if sf.update is not None: + sf.update.accept(self) + self.println(') {') + + self.indent() + self.visitBlock(sf) + self.dedent() + self.printdentln('}') + + + def visitStmtRangedFor(self, rf): + self.printdent('for (auto& ') + rf.var.accept(self) + self.write(' : ') + rf.iteree.accept(self) + self.println(') {') + + self.indent() + self.visitBlock(rf) + self.dedent() + self.printdentln('}') + + + def visitStmtSwitch(self, sw): + self.printdent('switch (') + sw.expr.accept(self) + self.println(') {') + self.indent() + self.visitBlock(sw) + self.dedent() + self.printdentln('}') + + def visitStmtBreak(self, sb): + self.printdentln('break;') + + + def visitStmtDecl(self, sd): + self.printdent() + sd.decl.accept(self) + if sd.initargs is not None: + self.write('(') + self.writeDeclList(sd.initargs) + self.write(')') + if sd.init is not None: + self.write(' = ') + sd.init.accept(self) + self.println(';') + + + def visitStmtExpr(self, se): + self.printdent() + se.expr.accept(self) + self.println(';') + + + def visitStmtReturn(self, sr): + self.printdent('return') + if sr.expr: + self.write (' ') + sr.expr.accept(self) + self.println(';') + + + def writeDeclList(self, decls): + # FIXME/cjones: try to do nice formatting of these guys + + ndecls = len(decls) + if 0 == ndecls: + return + elif 1 == ndecls: + decls[0].accept(self) + return + + self.indent() + self.indent() + for i, decl in enumerate(decls): + self.println() + self.printdent() + decl.accept(self) + if i != (ndecls-1): + self.write(',') + self.dedent() + self.dedent() + + def writeExprList(self, exprs): + # FIXME/cjones: try to do nice formatting and share code with + # writeDeclList() + nexprs = len(exprs) + for i, expr in enumerate(exprs): + expr.accept(self) + if i != (nexprs-1): + self.write(', ') diff --git a/ipc/ipdl/ipdl/lower.py b/ipc/ipdl/ipdl/lower.py new file mode 100644 index 000000000..f810cccb0 --- /dev/null +++ b/ipc/ipdl/ipdl/lower.py @@ -0,0 +1,4822 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +import os, re, sys +from copy import deepcopy +from collections import OrderedDict + +import ipdl.ast +import ipdl.builtin +from ipdl.cxx.ast import * +from ipdl.type import Actor, ActorType, ProcessGraph, TypeVisitor, builtinHeaderIncludes + +##----------------------------------------------------------------------------- +## "Public" interface to lowering +## +class LowerToCxx: + def lower(self, tu): + '''returns |[ header: File ], [ cpp : File ]| representing the +lowered form of |tu|''' + # annotate the AST with IPDL/C++ IR-type stuff used later + tu.accept(_DecorateWithCxxStuff()) + + # Any modifications to the filename scheme here need corresponding + # modifications in the ipdl.py driver script. + name = tu.name + pheader, pcpp = File(name +'.h'), File(name +'.cpp') + + _GenerateProtocolCode().lower(tu, pheader, pcpp) + headers = [ pheader ] + cpps = [ pcpp ] + + if tu.protocol: + pname = tu.protocol.name + + parentheader, parentcpp = File(pname +'Parent.h'), File(pname +'Parent.cpp') + _GenerateProtocolParentCode().lower( + tu, pname+'Parent', parentheader, parentcpp) + + childheader, childcpp = File(pname +'Child.h'), File(pname +'Child.cpp') + _GenerateProtocolChildCode().lower( + tu, pname+'Child', childheader, childcpp) + + headers += [ parentheader, childheader ] + cpps += [ parentcpp, childcpp ] + + return headers, cpps + + +##----------------------------------------------------------------------------- +## Helper code +## + +def hashfunc(value): + h = hash(value) % 2**32 + if h < 0: h += 2**32 + return h + +_NULL_ACTOR_ID = ExprLiteral.ZERO +_FREED_ACTOR_ID = ExprLiteral.ONE + +_DISCLAIMER = Whitespace('''// +// Automatically generated by ipdlc. +// Edit at your own risk +// + +''') + + +class _struct: pass + +def _namespacedHeaderName(name, namespaces): + pfx = '/'.join([ ns.name for ns in namespaces ]) + if pfx: + return pfx +'/'+ name + else: + return name + +def _ipdlhHeaderName(tu): + assert tu.filetype == 'header' + return _namespacedHeaderName(tu.name, tu.namespaces) + +def _protocolHeaderName(p, side=''): + if side: side = side.title() + base = p.name + side + return _namespacedHeaderName(base, p.namespaces) + +def _includeGuardMacroName(headerfile): + return re.sub(r'[./]', '_', headerfile.name) + +def _includeGuardStart(headerfile): + guard = _includeGuardMacroName(headerfile) + return [ CppDirective('ifndef', guard), + CppDirective('define', guard) ] + +def _includeGuardEnd(headerfile): + guard = _includeGuardMacroName(headerfile) + return [ CppDirective('endif', '// ifndef '+ guard) ] + +def _messageStartName(ptype): + return ptype.name() +'MsgStart' + +def _protocolId(ptype): + return ExprVar(_messageStartName(ptype)) + +def _protocolIdType(): + return Type.INT32 + +def _actorName(pname, side): + """|pname| is the protocol name. |side| is 'Parent' or 'Child'.""" + tag = side + if not tag[0].isupper(): tag = side.title() + return pname + tag + +def _actorIdType(): + return Type.INT32 + +def _actorTypeTagType(): + return Type.INT32 + +def _actorId(actor=None): + if actor is not None: + return ExprCall(ExprSelect(actor, '->', 'Id')) + return ExprCall(ExprVar('Id')) + +def _actorHId(actorhandle): + return ExprSelect(actorhandle, '.', 'mId') + +def _actorManager(actor): + return ExprCall(ExprSelect(actor, '->', 'Manager'), args=[]) + +def _actorState(actor): + return ExprSelect(actor, '->', 'mState') + +def _backstagePass(): + return ExprCall(ExprVar('mozilla::ipc::PrivateIPDLInterface')) + +def _iterType(ptr): + return Type('PickleIterator', ptr=ptr) + +def _nullState(proto=None): + pfx = '' + if proto is not None: pfx = proto.name() +'::' + return ExprVar(pfx +'__Null') + +def _errorState(proto=None): + pfx = '' + if proto is not None: pfx = proto.name() +'::' + return ExprVar(pfx +'__Error') + +def _deadState(proto=None): + pfx = '' + if proto is not None: pfx = proto.name() +'::' + return ExprVar(pfx +'__Dead') + +def _dyingState(proto=None): + pfx = '' + if proto is not None: pfx = proto.name() +'::' + return ExprVar(pfx +'__Dying') + +def _startState(proto=None, fq=False): + pfx = '' + if proto: + if fq: pfx = proto.fullname() +'::' + else: pfx = proto.name() +'::' + return ExprVar(pfx +'__Start') + +def _deleteId(): + return ExprVar('Msg___delete____ID') + +def _deleteReplyId(): + return ExprVar('Reply___delete____ID') + +def _lookupListener(idexpr): + return ExprCall(ExprVar('Lookup'), args=[ idexpr ]) + +def _shmemType(ptr=0, const=1, ref=0): + return Type('Shmem', ptr=ptr, ref=ref) + +def _rawShmemType(ptr=0): + return Type('Shmem::SharedMemory', ptr=ptr) + +def _shmemIdType(ptr=0): + return Type('Shmem::id_t', ptr=ptr) + +def _shmemTypeType(): + return Type('Shmem::SharedMemory::SharedMemoryType') + +def _shmemBackstagePass(): + return ExprCall(ExprVar( + 'Shmem::IHadBetterBeIPDLCodeCallingThis_OtherwiseIAmADoodyhead')) + +def _shmemCtor(rawmem, idexpr): + return ExprCall(ExprVar('Shmem'), + args=[ _shmemBackstagePass(), rawmem, idexpr ]) + +def _shmemId(shmemexpr): + return ExprCall(ExprSelect(shmemexpr, '.', 'Id'), + args=[ _shmemBackstagePass() ]) + +def _shmemSegment(shmemexpr): + return ExprCall(ExprSelect(shmemexpr, '.', 'Segment'), + args=[ _shmemBackstagePass() ]) + +def _shmemAlloc(size, type, unsafe): + # starts out UNprotected + return ExprCall(ExprVar('Shmem::Alloc'), + args=[ _shmemBackstagePass(), size, type, unsafe ]) + +def _shmemDealloc(rawmemvar): + return ExprCall(ExprVar('Shmem::Dealloc'), + args=[ _shmemBackstagePass(), rawmemvar ]) + +def _shmemShareTo(shmemvar, processvar, route): + return ExprCall(ExprSelect(shmemvar, '.', 'ShareTo'), + args=[ _shmemBackstagePass(), + processvar, route ]) + +def _shmemOpenExisting(descriptor, outid): + # starts out protected + return ExprCall(ExprVar('Shmem::OpenExisting'), + args=[ _shmemBackstagePass(), + # true => protect + descriptor, outid, ExprLiteral.TRUE ]) + +def _shmemUnshareFrom(shmemvar, processvar, route): + return ExprCall(ExprSelect(shmemvar, '.', 'UnshareFrom'), + args=[ _shmemBackstagePass(), + processvar, route ]) + +def _shmemForget(shmemexpr): + return ExprCall(ExprSelect(shmemexpr, '.', 'forget'), + args=[ _shmemBackstagePass() ]) + +def _shmemRevokeRights(shmemexpr): + return ExprCall(ExprSelect(shmemexpr, '.', 'RevokeRights'), + args=[ _shmemBackstagePass() ]) + +def _lookupShmem(idexpr): + return ExprCall(ExprVar('LookupSharedMemory'), args=[ idexpr ]) + +def _makeForwardDeclForQClass(clsname, quals, cls=1, struct=0): + fd = ForwardDecl(clsname, cls=cls, struct=struct) + if 0 == len(quals): + return fd + + outerns = Namespace(quals[0]) + innerns = outerns + for ns in quals[1:]: + tmpns = Namespace(ns) + innerns.addstmt(tmpns) + innerns = tmpns + + innerns.addstmt(fd) + return outerns + +def _makeForwardDeclForActor(ptype, side): + return _makeForwardDeclForQClass(_actorName(ptype.qname.baseid, side), + ptype.qname.quals) + +def _makeForwardDecl(type): + return _makeForwardDeclForQClass(type.name(), type.qname.quals) + + +def _putInNamespaces(cxxthing, namespaces): + """|namespaces| is in order [ outer, ..., inner ]""" + if 0 == len(namespaces): return cxxthing + + outerns = Namespace(namespaces[0].name) + innerns = outerns + for ns in namespaces[1:]: + newns = Namespace(ns.name) + innerns.addstmt(newns) + innerns = newns + innerns.addstmt(cxxthing) + return outerns + +def _sendPrefix(msgtype): + """Prefix of the name of the C++ method that sends |msgtype|.""" + if msgtype.isInterrupt(): + return 'Call' + return 'Send' + +def _recvPrefix(msgtype): + """Prefix of the name of the C++ method that handles |msgtype|.""" + if msgtype.isInterrupt(): + return 'Answer' + return 'Recv' + +def _flatTypeName(ipdltype): + """Return a 'flattened' IPDL type name that can be used as an +identifier. +E.g., |Foo[]| --> |ArrayOfFoo|.""" + # NB: this logic depends heavily on what IPDL types are allowed to + # be constructed; e.g., Foo[][] is disallowed. needs to be kept in + # sync with grammar. + if ipdltype.isIPDL() and ipdltype.isArray(): + return 'ArrayOf'+ ipdltype.basetype.name() + return ipdltype.name() + + +def _hasVisibleActor(ipdltype): + """Return true iff a C++ decl of |ipdltype| would have an Actor* type. +For example: |Actor[]| would turn into |Array<ActorParent*>|, so this +function would return true for |Actor[]|.""" + return (ipdltype.isIPDL() + and (ipdltype.isActor() + or (ipdltype.isArray() + and _hasVisibleActor(ipdltype.basetype)))) + +def _abortIfFalse(cond, msg): + return StmtExpr(ExprCall( + ExprVar('MOZ_RELEASE_ASSERT'), + [ cond, ExprLiteral.String(msg) ])) + +def _refptr(T): + return Type('RefPtr', T=T) + +def _refptrGet(expr): + return ExprCall(ExprSelect(expr, '.', 'get')) + +def _refptrForget(expr): + return ExprCall(ExprSelect(expr, '.', 'forget')) + +def _refptrTake(expr): + return ExprCall(ExprSelect(expr, '.', 'take')) + +def _uniqueptr(T): + return Type('UniquePtr', T=T) + +def _uniqueptrGet(expr): + return ExprCall(ExprSelect(expr, '.', 'get')) + +def _cxxArrayType(basetype, const=0, ref=0): + return Type('nsTArray', T=basetype, const=const, ref=ref, hasimplicitcopyctor=False) + +def _cxxManagedContainerType(basetype, const=0, ref=0): + return Type('ManagedContainer', T=basetype, + const=const, ref=ref, hasimplicitcopyctor=False) + +def _callCxxArrayLength(arr): + return ExprCall(ExprSelect(arr, '.', 'Length')) + +def _callCxxArraySetLength(arr, lenexpr, sel='.'): + return ExprCall(ExprSelect(arr, sel, 'SetLength'), + args=[ lenexpr ]) + +def _callCxxSwapArrayElements(arr1, arr2, sel='.'): + return ExprCall(ExprSelect(arr1, sel, 'SwapElements'), + args=[ arr2 ]) + +def _callInsertManagedActor(managees, actor): + return ExprCall(ExprSelect(managees, '.', 'PutEntry'), + args=[ actor ]) + +def _callRemoveManagedActor(managees, actor): + return ExprCall(ExprSelect(managees, '.', 'RemoveEntry'), + args=[ actor ]) + +def _callClearManagedActors(managees): + return ExprCall(ExprSelect(managees, '.', 'Clear')) + +def _callHasManagedActor(managees, actor): + return ExprCall(ExprSelect(managees, '.', 'Contains'), args=[ actor ]) + +def _otherSide(side): + if side == 'child': return 'parent' + if side == 'parent': return 'child' + assert 0 + +def _sideToTransportMode(side): + if side == 'parent': mode = 'SERVER' + elif side == 'child': mode = 'CLIENT' + return ExprVar('mozilla::ipc::Transport::MODE_'+ mode) + +def _ifLogging(topLevelProtocol, stmts): + iflogging = StmtIf(ExprCall(ExprVar('mozilla::ipc::LoggingEnabledFor'), + args=[ topLevelProtocol ])) + iflogging.addifstmts(stmts) + return iflogging + +# XXX we need to remove these and install proper error handling +def _printErrorMessage(msg): + if isinstance(msg, str): + msg = ExprLiteral.String(msg) + return StmtExpr( + ExprCall(ExprVar('NS_ERROR'), args=[ msg ])) + +def _protocolErrorBreakpoint(msg): + if isinstance(msg, str): + msg = ExprLiteral.String(msg) + return StmtExpr(ExprCall(ExprVar('mozilla::ipc::ProtocolErrorBreakpoint'), + args=[ msg ])) + +def _printWarningMessage(msg): + if isinstance(msg, str): + msg = ExprLiteral.String(msg) + return StmtExpr( + ExprCall(ExprVar('NS_WARNING'), args=[ msg ])) + +def _fatalError(msg): + return StmtExpr( + ExprCall(ExprVar('FatalError'), args=[ ExprLiteral.String(msg) ])) + +def _logicError(msg): + return StmtExpr( + ExprCall(ExprVar('mozilla::ipc::LogicError'), args=[ ExprLiteral.String(msg) ])) + +def _arrayLengthReadError(elementname): + return StmtExpr( + ExprCall(ExprVar('mozilla::ipc::ArrayLengthReadError'), + args=[ ExprLiteral.String(elementname) ])) + +def _unionTypeReadError(unionname): + return StmtExpr( + ExprCall(ExprVar('mozilla::ipc::UnionTypeReadError'), + args=[ ExprLiteral.String(unionname) ])) + +def _killProcess(pid): + return ExprCall( + ExprVar('base::KillProcess'), + args=[ pid, + # XXX this is meaningless on POSIX + ExprVar('base::PROCESS_END_KILLED_BY_USER'), + ExprLiteral.FALSE ]) + +def _badTransition(): + # FIXME: make this a FatalError() + return [ _printWarningMessage('bad state transition!') ] + +# Results that IPDL-generated code returns back to *Channel code. +# Users never see these +class _Result: + @staticmethod + def Type(): + return Type('Result') + + Processed = ExprVar('MsgProcessed') + NotKnown = ExprVar('MsgNotKnown') + NotAllowed = ExprVar('MsgNotAllowed') + PayloadError = ExprVar('MsgPayloadError') + ProcessingError = ExprVar('MsgProcessingError') + RouteError = ExprVar('MsgRouteError') + ValuError = ExprVar('MsgValueError') # [sic] + +# these |errfn*| are functions that generate code to be executed on an +# error, such as "bad actor ID". each is given a Python string +# containing a description of the error + +# used in user-facing Send*() methods +def errfnSend(msg, errcode=ExprLiteral.FALSE): + return [ + _fatalError(msg), + StmtReturn(errcode) + ] + +def errfnSendCtor(msg): return errfnSend(msg, errcode=ExprLiteral.NULL) + +# TODO should this error handling be strengthened for dtors? +def errfnSendDtor(msg): + return [ + _printErrorMessage(msg), + StmtReturn.FALSE + ] + +# used in |OnMessage*()| handlers that hand in-messages off to Recv*() +# interface methods +def errfnRecv(msg, errcode=_Result.ValuError): + return [ + _fatalError(msg), + StmtReturn(errcode) + ] + +# used in Read() methods +def errfnRead(msg): + return [ _fatalError(msg), StmtReturn.FALSE ] + +def errfnArrayLength(elementname): + return [ _arrayLengthReadError(elementname), StmtReturn.FALSE ] + +def errfnUnionType(unionname): + return [ _unionTypeReadError(unionname), StmtReturn.FALSE ] + +def _destroyMethod(): + return ExprVar('ActorDestroy') + +class _DestroyReason: + @staticmethod + def Type(): return Type('ActorDestroyReason') + + Deletion = ExprVar('Deletion') + AncestorDeletion = ExprVar('AncestorDeletion') + NormalShutdown = ExprVar('NormalShutdown') + AbnormalShutdown = ExprVar('AbnormalShutdown') + FailedConstructor = ExprVar('FailedConstructor') + +##----------------------------------------------------------------------------- +## Intermediate representation (IR) nodes used during lowering + +class _ConvertToCxxType(TypeVisitor): + def __init__(self, side, fq): + self.side = side + self.fq = fq + + def typename(self, thing): + if self.fq: + return thing.fullname() + return thing.name() + + def visitBuiltinCxxType(self, t): + return Type(self.typename(t)) + + def visitImportedCxxType(self, t): + return Type(self.typename(t)) + + def visitActorType(self, a): + return Type(_actorName(self.typename(a.protocol), self.side), ptr=1) + + def visitStructType(self, s): + return Type(self.typename(s)) + + def visitUnionType(self, u): + return Type(self.typename(u)) + + def visitArrayType(self, a): + basecxxtype = a.basetype.accept(self) + return _cxxArrayType(basecxxtype) + + def visitShmemType(self, s): + return Type(self.typename(s)) + + def visitFDType(self, s): + return Type(self.typename(s)) + + def visitEndpointType(self, s): + return Type(self.typename(s)) + + def visitProtocolType(self, p): assert 0 + def visitMessageType(self, m): assert 0 + def visitVoidType(self, v): assert 0 + def visitStateType(self, st): assert 0 + +def _cxxBareType(ipdltype, side, fq=0): + return ipdltype.accept(_ConvertToCxxType(side, fq)) + +def _cxxRefType(ipdltype, side): + t = _cxxBareType(ipdltype, side) + t.ref = 1 + return t + +def _cxxConstRefType(ipdltype, side): + t = _cxxBareType(ipdltype, side) + if ipdltype.isIPDL() and ipdltype.isActor(): + return t + if ipdltype.isIPDL() and ipdltype.isShmem(): + t.ref = 1 + return t + t.const = 1 + t.ref = 1 + return t + +def _cxxMoveRefType(ipdltype, side): + t = _cxxBareType(ipdltype, side) + if ipdltype.isIPDL() and (ipdltype.isArray() or ipdltype.isShmem() or ipdltype.isEndpoint()): + t.ref = 2 + return t + return _cxxConstRefType(ipdltype, side) + +def _cxxPtrToType(ipdltype, side): + t = _cxxBareType(ipdltype, side) + if ipdltype.isIPDL() and ipdltype.isActor(): + t.ptr = 0 + t.ptrptr = 1 + return t + t.ptr = 1 + return t + +def _cxxConstPtrToType(ipdltype, side): + t = _cxxBareType(ipdltype, side) + if ipdltype.isIPDL() and ipdltype.isActor(): + t.ptr = 0 + t.ptrconstptr = 1 + return t + t.const = 1 + t.ptr = 1 + return t + +def _allocMethod(ptype, side): + return ExprVar('Alloc'+ str(Actor(ptype, side))) + +def _deallocMethod(ptype, side): + return ExprVar('Dealloc'+ str(Actor(ptype, side))) + +## +## A _HybridDecl straddles IPDL and C++ decls. It knows which C++ +## types correspond to which IPDL types, and it also knows how +## serialize and deserialize "special" IPDL C++ types. +## +class _HybridDecl: + """A hybrid decl stores both an IPDL type and all the C++ type +info needed by later passes, along with a basic name for the decl.""" + def __init__(self, ipdltype, name): + self.ipdltype = ipdltype + self.name = name + self.idnum = 0 + + def var(self): + return ExprVar(self.name) + + def bareType(self, side): + """Return this decl's unqualified C++ type.""" + return _cxxBareType(self.ipdltype, side) + + def refType(self, side): + """Return this decl's C++ type as a 'reference' type, which is not +necessarily a C++ reference.""" + return _cxxRefType(self.ipdltype, side) + + def constRefType(self, side): + """Return this decl's C++ type as a const, 'reference' type.""" + return _cxxConstRefType(self.ipdltype, side) + + def rvalueRefType(self, side): + """Return this decl's C++ type as an r-value 'reference' type.""" + return _cxxMoveRefType(self.ipdltype, side) + + def ptrToType(self, side): + return _cxxPtrToType(self.ipdltype, side) + + def constPtrToType(self, side): + return _cxxConstPtrToType(self.ipdltype, side) + + def inType(self, side): + """Return this decl's C++ Type with inparam semantics.""" + if self.ipdltype.isIPDL() and self.ipdltype.isActor(): + return self.bareType(side) + return self.constRefType(side) + + def moveType(self, side): + """Return this decl's C++ Type with move semantics.""" + if self.ipdltype.isIPDL() and self.ipdltype.isActor(): + return self.bareType(side) + return self.rvalueRefType(side); + + def outType(self, side): + """Return this decl's C++ Type with outparam semantics.""" + t = self.bareType(side) + if self.ipdltype.isIPDL() and self.ipdltype.isActor(): + t.ptr = 0; t.ptrptr = 1 + return t + t.ptr = 1 + return t + +##-------------------------------------------------- + +class HasFQName: + def fqClassName(self): + return self.decl.type.fullname() + +class _CompoundTypeComponent(_HybridDecl): + def __init__(self, ipdltype, name, side, ct): + _HybridDecl.__init__(self, ipdltype, name) + self.side = side + self.special = _hasVisibleActor(ipdltype) + self.recursive = ct.decl.type.mutuallyRecursiveWith(ipdltype) + + def internalType(self): + if self.recursive: + return self.ptrToType() + else: + return self.bareType() + + # @override the following methods to pass |self.side| instead of + # forcing the caller to remember which side we're declared to + # represent. + def bareType(self, side=None): + return _HybridDecl.bareType(self, self.side) + def refType(self, side=None): + return _HybridDecl.refType(self, self.side) + def constRefType(self, side=None): + return _HybridDecl.constRefType(self, self.side) + def ptrToType(self, side=None): + return _HybridDecl.ptrToType(self, self.side) + def constPtrToType(self, side=None): + return _HybridDecl.constPtrToType(self, self.side) + def inType(self, side=None): + return _HybridDecl.inType(self, self.side) + + +class StructDecl(ipdl.ast.StructDecl, HasFQName): + @staticmethod + def upgrade(structDecl): + assert isinstance(structDecl, ipdl.ast.StructDecl) + structDecl.__class__ = StructDecl + return structDecl + +class _StructField(_CompoundTypeComponent): + def __init__(self, ipdltype, name, sd, side=None): + self.basename = name + fname = name + special = _hasVisibleActor(ipdltype) + if special: + fname += side.title() + + _CompoundTypeComponent.__init__(self, ipdltype, fname, side, sd) + + def getMethod(self, thisexpr=None, sel='.'): + meth = self.var() + if thisexpr is not None: + return ExprSelect(thisexpr, sel, meth.name) + return meth + + def initExpr(self, thisexpr): + expr = ExprCall(self.getMethod(thisexpr=thisexpr)) + if self.ipdltype.isIPDL() and self.ipdltype.isActor(): + expr = ExprCast(expr, self.bareType(), const=1) + return expr + + def refExpr(self, thisexpr=None): + ref = self.memberVar() + if thisexpr is not None: + ref = ExprSelect(thisexpr, '.', ref.name) + if self.recursive: + ref = ExprDeref(ref) + return ref + + def constRefExpr(self, thisexpr=None): + # sigh, gross hack + refexpr = self.refExpr(thisexpr) + if 'Shmem' == self.ipdltype.name(): + refexpr = ExprCast(refexpr, Type('Shmem', ref=1), const=1) + if 'FileDescriptor' == self.ipdltype.name(): + refexpr = ExprCast(refexpr, Type('FileDescriptor', ref=1), const=1) + return refexpr + + def argVar(self): + return ExprVar('_'+ self.name) + + def memberVar(self): + return ExprVar(self.name + '_') + + def initStmts(self): + if self.recursive: + return [ StmtExpr(ExprAssn(self.memberVar(), + ExprNew(self.bareType()))) ] + elif self.ipdltype.isIPDL() and self.ipdltype.isActor(): + return [ StmtExpr(ExprAssn(self.memberVar(), + ExprLiteral.NULL)) ] + else: + return [] + + def destructStmts(self): + if self.recursive: + return [ StmtExpr(ExprDelete(self.memberVar())) ] + else: + return [] + + +class UnionDecl(ipdl.ast.UnionDecl, HasFQName): + def callType(self, var=None): + func = ExprVar('type') + if var is not None: + func = ExprSelect(var, '.', func.name) + return ExprCall(func) + + @staticmethod + def upgrade(unionDecl): + assert isinstance(unionDecl, ipdl.ast.UnionDecl) + unionDecl.__class__ = UnionDecl + return unionDecl + + +class _UnionMember(_CompoundTypeComponent): + """Not in the AFL sense, but rather a member (e.g. |int;|) of an +IPDL union type.""" + def __init__(self, ipdltype, ud, side=None, other=None): + flatname = _flatTypeName(ipdltype) + special = _hasVisibleActor(ipdltype) + if special: + flatname += side.title() + + _CompoundTypeComponent.__init__(self, ipdltype, 'V'+ flatname, side, ud) + self.flattypename = flatname + if special: + if other is not None: + self.other = other + else: + self.other = _UnionMember(ipdltype, ud, _otherSide(side), self) + + def enum(self): + return 'T' + self.flattypename + + def pqEnum(self): + return self.ud.name +'::'+ self.enum() + + def enumvar(self): + return ExprVar(self.enum()) + + def unionType(self): + """Type used for storage in generated C union decl.""" + if self.recursive: + return self.ptrToType() + else: + return Type('mozilla::AlignedStorage2', T=self.internalType()) + + def unionValue(self): + # NB: knows that Union's storage C union is named |mValue| + return ExprSelect(ExprVar('mValue'), '.', self.name) + + def typedef(self): + return self.flattypename +'__tdef' + + def callGetConstPtr(self): + """Return an expression of type self.constptrToSelfType()""" + return ExprCall(ExprVar(self.getConstPtrName())) + + def callGetPtr(self): + """Return an expression of type self.ptrToSelfType()""" + return ExprCall(ExprVar(self.getPtrName())) + + def callOperatorEq(self, rhs): + if self.ipdltype.isIPDL() and self.ipdltype.isActor(): + rhs = ExprCast(rhs, self.bareType(), const=1) + return ExprAssn(ExprDeref(self.callGetPtr()), rhs) + + def callCtor(self, expr=None): + assert not isinstance(expr, list) + + if expr is None: + args = None + elif self.ipdltype.isIPDL() and self.ipdltype.isActor(): + args = [ ExprCast(expr, self.bareType(), const=1) ] + else: + args = [ expr ] + + if self.recursive: + return ExprAssn(self.callGetPtr(), + ExprNew(self.bareType(self.side), + args=args)) + else: + return ExprNew(self.bareType(self.side), + args=args, + newargs=[ ExprVar('mozilla::KnownNotNull'), self.callGetPtr() ]) + + def callDtor(self): + if self.recursive: + return ExprDelete(self.callGetPtr()) + else: + return ExprCall( + ExprSelect(self.callGetPtr(), '->', '~'+ self.typedef())) + + def getTypeName(self): return 'get_'+ self.flattypename + def getConstTypeName(self): return 'get_'+ self.flattypename + + def getOtherTypeName(self): return 'get_'+ self.otherflattypename + + def getPtrName(self): return 'ptr_'+ self.flattypename + def getConstPtrName(self): return 'constptr_'+ self.flattypename + + def ptrToSelfExpr(self): + """|*ptrToSelfExpr()| has type |self.bareType()|""" + v = self.unionValue() + if self.recursive: + return v + else: + return ExprCall(ExprSelect(v, '.', 'addr')) + + def constptrToSelfExpr(self): + """|*constptrToSelfExpr()| has type |self.constType()|""" + v = self.unionValue() + if self.recursive: + return v + return ExprCall(ExprSelect(v, '.', 'addr')) + + def ptrToInternalType(self): + t = self.ptrToType() + if self.recursive: + t.ref = 1 + return t + + def defaultValue(self): + # Use the default constructor for any class that does not have an + # implicit copy constructor. + if not self.bareType().hasimplicitcopyctor: + return None + + if self.ipdltype.isIPDL() and self.ipdltype.isActor(): + return ExprLiteral.NULL + # XXX sneaky here, maybe need ExprCtor()? + return ExprCall(self.bareType()) + + def getConstValue(self): + v = ExprDeref(self.callGetConstPtr()) + # sigh + if 'Shmem' == self.ipdltype.name(): + v = ExprCast(v, Type('Shmem', ref=1), const=1) + if 'FileDescriptor' == self.ipdltype.name(): + v = ExprCast(v, Type('FileDescriptor', ref=1), const=1) + return v + +##-------------------------------------------------- + +class MessageDecl(ipdl.ast.MessageDecl): + def baseName(self): + return self.name + + def recvMethod(self): + name = _recvPrefix(self.decl.type) + self.baseName() + if self.decl.type.isCtor(): + name += 'Constructor' + return ExprVar(name) + + def sendMethod(self): + name = _sendPrefix(self.decl.type) + self.baseName() + if self.decl.type.isCtor(): + name += 'Constructor' + return ExprVar(name) + + def hasReply(self): + return (self.decl.type.hasReply() + or self.decl.type.isCtor() + or self.decl.type.isDtor()) + + def msgCtorFunc(self): + return 'Msg_%s'% (self.decl.progname) + + def prettyMsgName(self, pfx=''): + return pfx + self.msgCtorFunc() + + def pqMsgCtorFunc(self): + return '%s::%s'% (self.namespace, self.msgCtorFunc()) + + def msgId(self): return self.msgCtorFunc()+ '__ID' + def pqMsgId(self): + return '%s::%s'% (self.namespace, self.msgId()) + + def replyCtorFunc(self): + return 'Reply_%s'% (self.decl.progname) + + def pqReplyCtorFunc(self): + return '%s::%s'% (self.namespace, self.replyCtorFunc()) + + def replyId(self): return self.replyCtorFunc()+ '__ID' + def pqReplyId(self): + return '%s::%s'% (self.namespace, self.replyId()) + + def prettyReplyName(self, pfx=''): + return pfx + self.replyCtorFunc() + + def actorDecl(self): + return self.params[0] + + def makeCxxParams(self, paramsems='in', returnsems='out', + side=None, implicit=1): + """Return a list of C++ decls per the spec'd configuration. +|params| and |returns| is the C++ semantics of those: 'in', 'out', or None.""" + + def makeDecl(d, sems): + if sems is 'in': + return Decl(d.inType(side), d.name) + elif sems is 'move': + return Decl(d.moveType(side), d.name) + elif sems is 'out': + return Decl(d.outType(side), d.name) + else: assert 0 + + cxxparams = [ ] + if paramsems is not None: + cxxparams.extend([ makeDecl(d, paramsems) for d in self.params ]) + + if returnsems is not None: + cxxparams.extend([ makeDecl(r, returnsems) for r in self.returns ]) + + if not implicit and self.decl.type.hasImplicitActorParam(): + cxxparams = cxxparams[1:] + + return cxxparams + + def makeCxxArgs(self, paramsems='in', retsems='out', retcallsems='out', + implicit=1): + assert not retcallsems or retsems # retcallsems => returnsems + cxxargs = [ ] + + if paramsems is 'move': + cxxargs.extend([ ExprMove(p.var()) for p in self.params ]) + elif paramsems is 'in': + cxxargs.extend([ p.var() for p in self.params ]) + else: + assert False + + for ret in self.returns: + if retsems is 'in': + if retcallsems is 'in': + cxxargs.append(ret.var()) + elif retcallsems is 'out': + cxxargs.append(ExprAddrOf(ret.var())) + else: assert 0 + elif retsems is 'out': + if retcallsems is 'in': + cxxargs.append(ExprDeref(ret.var())) + elif retcallsems is 'out': + cxxargs.append(ret.var()) + else: assert 0 + + if not implicit: + assert self.decl.type.hasImplicitActorParam() + cxxargs = cxxargs[1:] + + return cxxargs + + + @staticmethod + def upgrade(messageDecl): + assert isinstance(messageDecl, ipdl.ast.MessageDecl) + if messageDecl.decl.type.hasImplicitActorParam(): + messageDecl.params.insert( + 0, + _HybridDecl( + ipdl.type.ActorType( + messageDecl.decl.type.constructedType()), + 'actor')) + messageDecl.__class__ = MessageDecl + return messageDecl + +##-------------------------------------------------- +def _semsToChannelParts(sems): + return [ 'mozilla', 'ipc', 'MessageChannel' ] + +def _usesShmem(p): + for md in p.messageDecls: + for param in md.inParams: + if ipdl.type.hasshmem(param.type): + return True + for ret in md.outParams: + if ipdl.type.hasshmem(ret.type): + return True + return False + +def _subtreeUsesShmem(p): + if _usesShmem(p): + return True + + ptype = p.decl.type + for mgd in ptype.manages: + if ptype is not mgd: + if _subtreeUsesShmem(mgd._ast): + return True + return False + +class Protocol(ipdl.ast.Protocol): + def cxxTypedefs(self): + return self.decl.cxxtypedefs + + def sendSems(self): + return self.decl.type.toplevel().sendSemantics + + def channelName(self): + return '::'.join(_semsToChannelParts(self.sendSems())) + + def channelSel(self): + if self.decl.type.isToplevel(): return '.' + return '->' + + def channelType(self): + return Type('Channel', ptr=not self.decl.type.isToplevel()) + + def channelHeaderFile(self): + return '/'.join(_semsToChannelParts(self.sendSems())) +'.h' + + def managerInterfaceType(self, ptr=0): + return Type('mozilla::ipc::IProtocol', ptr=ptr) + + def openedProtocolInterfaceType(self, ptr=0): + return Type('mozilla::ipc::IToplevelProtocol', + ptr=ptr) + + def _ipdlmgrtype(self): + assert 1 == len(self.decl.type.managers) + for mgr in self.decl.type.managers: return mgr + + def managerActorType(self, side, ptr=0): + return Type(_actorName(self._ipdlmgrtype().name(), side), + ptr=ptr) + + def stateMethod(self): + return ExprVar('state'); + + def registerMethod(self): + return ExprVar('Register') + + def registerIDMethod(self): + return ExprVar('RegisterID') + + def lookupIDMethod(self): + return ExprVar('Lookup') + + def unregisterMethod(self, actorThis=None): + if actorThis is not None: + return ExprSelect(actorThis, '->', 'Unregister') + return ExprVar('Unregister') + + def removeManageeMethod(self): + return ExprVar('RemoveManagee') + + def createSharedMemory(self): + return ExprVar('CreateSharedMemory') + + def lookupSharedMemory(self): + return ExprVar('LookupSharedMemory') + + def isTrackingSharedMemory(self): + return ExprVar('IsTrackingSharedMemory') + + def destroySharedMemory(self): + return ExprVar('DestroySharedMemory') + + def otherPidMethod(self): + return ExprVar('OtherPid') + + def callOtherPid(self, actorThis=None): + fn = self.otherPidMethod() + if actorThis is not None: + fn = ExprSelect(actorThis, '->', fn.name) + return ExprCall(fn) + + def getChannelMethod(self): + return ExprVar('GetIPCChannel') + + def callGetChannel(self, actorThis=None): + fn = self.getChannelMethod() + if actorThis is not None: + fn = ExprSelect(actorThis, '->', fn.name) + return ExprCall(fn) + + def processingErrorVar(self): + assert self.decl.type.isToplevel() + return ExprVar('ProcessingError') + + def shouldContinueFromTimeoutVar(self): + assert self.decl.type.isToplevel() + return ExprVar('ShouldContinueFromReplyTimeout') + + def enteredCxxStackVar(self): + assert self.decl.type.isToplevel() + return ExprVar('EnteredCxxStack') + + def exitedCxxStackVar(self): + assert self.decl.type.isToplevel() + return ExprVar('ExitedCxxStack') + + def enteredCallVar(self): + assert self.decl.type.isToplevel() + return ExprVar('EnteredCall') + + def exitedCallVar(self): + assert self.decl.type.isToplevel() + return ExprVar('ExitedCall') + + def onCxxStackVar(self): + assert self.decl.type.isToplevel() + return ExprVar('IsOnCxxStack') + + # an actor's C++ private variables + def channelVar(self, actorThis=None): + if actorThis is not None: + return ExprSelect(actorThis, '->', 'mChannel') + return ExprVar('mChannel') + + def routingId(self, actorThis=None): + if self.decl.type.isToplevel(): + return ExprVar('MSG_ROUTING_CONTROL') + if actorThis is not None: + return ExprCall(ExprSelect(actorThis, '->', 'Id')) + return ExprCall(ExprVar('Id')) + + def stateVar(self, actorThis=None): + if actorThis is not None: + return ExprSelect(actorThis, '->', 'mState') + return ExprVar('mState') + + def fqStateType(self): + return Type(self.decl.type.name() +'::State') + + def startState(self): + return _startState(self.decl.type) + + def nullState(self): + return _nullState(self.decl.type) + + def deadState(self): + return _deadState(self.decl.type) + + def managerVar(self, thisexpr=None): + assert thisexpr is not None or not self.decl.type.isToplevel() + mvar = ExprCall(ExprVar('Manager'), args=[]) + if thisexpr is not None: + mvar = ExprCall(ExprSelect(thisexpr, '->', 'Manager'), args=[]) + return mvar + + def managedCxxType(self, actortype, side): + assert self.decl.type.isManagerOf(actortype) + return Type(_actorName(actortype.name(), side), ptr=1) + + def managedMethod(self, actortype, side): + assert self.decl.type.isManagerOf(actortype) + return ExprVar('Managed'+ _actorName(actortype.name(), side)) + + def managedVar(self, actortype, side): + assert self.decl.type.isManagerOf(actortype) + return ExprVar('mManaged'+ _actorName(actortype.name(), side)) + + def managedVarType(self, actortype, side, const=0, ref=0): + assert self.decl.type.isManagerOf(actortype) + return _cxxManagedContainerType(Type(_actorName(actortype.name(), side)), + const=const, ref=ref) + + # XXX this is sucky, fix + def usesShmem(self): + return _usesShmem(self) + + def subtreeUsesShmem(self): + return _subtreeUsesShmem(self) + + @staticmethod + def upgrade(protocol): + assert isinstance(protocol, ipdl.ast.Protocol) + protocol.__class__ = Protocol + return protocol + + +class TranslationUnit(ipdl.ast.TranslationUnit): + @staticmethod + def upgrade(tu): + assert isinstance(tu, ipdl.ast.TranslationUnit) + tu.__class__ = TranslationUnit + return tu + +##----------------------------------------------------------------------------- + +class _DecorateWithCxxStuff(ipdl.ast.Visitor): + """Phase 1 of lowering: decorate the IPDL AST with information +relevant to C++ code generation. + +This pass results in an AST that is a poor man's "IR"; in reality, a +"hybrid" AST mainly consisting of IPDL nodes with new C++ info along +with some new IPDL/C++ nodes that are tuned for C++ codegen.""" + + def __init__(self): + self.visitedTus = set() + # the set of typedefs that allow generated classes to + # reference known C++ types by their "short name" rather than + # fully-qualified name. e.g. |Foo| rather than |a::b::Foo|. + self.typedefs = [ ] + self.typedefSet = set([ Typedef(Type('mozilla::ipc::ActorHandle'), + 'ActorHandle'), + Typedef(Type('base::ProcessId'), + 'ProcessId'), + Typedef(Type('mozilla::ipc::ProtocolId'), + 'ProtocolId'), + Typedef(Type('mozilla::ipc::Transport'), + 'Transport'), + Typedef(Type('mozilla::ipc::Endpoint'), + 'Endpoint', ['FooSide']), + Typedef(Type('mozilla::ipc::TransportDescriptor'), + 'TransportDescriptor') ]) + self.protocolName = None + + def visitTranslationUnit(self, tu): + if tu not in self.visitedTus: + self.visitedTus.add(tu) + ipdl.ast.Visitor.visitTranslationUnit(self, tu) + if not isinstance(tu, TranslationUnit): + TranslationUnit.upgrade(tu) + self.typedefs[:] = sorted(list(self.typedefSet)) + + def visitInclude(self, inc): + if inc.tu.filetype == 'header': + inc.tu.accept(self) + + def visitProtocol(self, pro): + self.protocolName = pro.name + pro.decl.cxxtypedefs = self.typedefs + Protocol.upgrade(pro) + return ipdl.ast.Visitor.visitProtocol(self, pro) + + + def visitUsingStmt(self, using): + if using.decl.fullname is not None: + self.typedefSet.add(Typedef(Type(using.decl.fullname), + using.decl.shortname)) + + def visitStructDecl(self, sd): + if not isinstance(sd, StructDecl): + sd.decl.special = 0 + newfields = [ ] + for f in sd.fields: + ftype = f.decl.type + if _hasVisibleActor(ftype): + sd.decl.special = 1 + # if ftype has a visible actor, we need both + # |ActorParent| and |ActorChild| fields + newfields.append(_StructField(ftype, f.name, sd, + side='parent')) + newfields.append(_StructField(ftype, f.name, sd, + side='child')) + else: + newfields.append(_StructField(ftype, f.name, sd)) + sd.fields = newfields + StructDecl.upgrade(sd) + + if sd.decl.fullname is not None: + self.typedefSet.add(Typedef(Type(sd.fqClassName()), sd.name)) + + + def visitUnionDecl(self, ud): + ud.decl.special = 0 + newcomponents = [ ] + for ctype in ud.decl.type.components: + if _hasVisibleActor(ctype): + ud.decl.special = 1 + # if ctype has a visible actor, we need both + # |ActorParent| and |ActorChild| union members + newcomponents.append(_UnionMember(ctype, ud, side='parent')) + newcomponents.append(_UnionMember(ctype, ud, side='child')) + else: + newcomponents.append(_UnionMember(ctype, ud)) + ud.components = newcomponents + UnionDecl.upgrade(ud) + + if ud.decl.fullname is not None: + self.typedefSet.add(Typedef(Type(ud.fqClassName()), ud.name)) + + + def visitDecl(self, decl): + return _HybridDecl(decl.type, decl.progname) + + def visitMessageDecl(self, md): + md.namespace = self.protocolName + md.params = [ param.accept(self) for param in md.inParams ] + md.returns = [ ret.accept(self) for ret in md.outParams ] + MessageDecl.upgrade(md) + + def visitTransitionStmt(self, ts): + name = ts.state.decl.progname + ts.state.decl.cxxname = name + ts.state.decl.cxxenum = ExprVar(self.protocolName +'::'+ name) + +##----------------------------------------------------------------------------- + +def msgenums(protocol, pretty=False): + msgenum = TypeEnum('MessageType') + msgstart = _messageStartName(protocol.decl.type) +' << 16' + msgenum.addId(protocol.name + 'Start', msgstart) + + for md in protocol.messageDecls: + msgenum.addId(md.prettyMsgName() if pretty else md.msgId()) + if md.hasReply(): + msgenum.addId(md.prettyReplyName() if pretty else md.replyId()) + + msgenum.addId(protocol.name +'End') + return msgenum + +class _GenerateProtocolCode(ipdl.ast.Visitor): + '''Creates code common to both the parent and child actors.''' + def __init__(self): + self.protocol = None # protocol we're generating a class for + self.hdrfile = None # what will become Protocol.h + self.cppfile = None # what will become Protocol.cpp + self.cppIncludeHeaders = [] + self.structUnionDefns = [] + self.funcDefns = [] + + def lower(self, tu, cxxHeaderFile, cxxFile): + self.protocol = tu.protocol + self.hdrfile = cxxHeaderFile + self.cppfile = cxxFile + tu.accept(self) + + def visitTranslationUnit(self, tu): + hf = self.hdrfile + + hf.addthing(_DISCLAIMER) + hf.addthings(_includeGuardStart(hf)) + hf.addthing(Whitespace.NL) + + for inc in builtinHeaderIncludes: + self.visitBuiltinCxxInclude(inc) + + # Compute the set of includes we need for declared structure/union + # classes for this protocol. + typesToIncludes = {} + for using in tu.using: + typestr = str(using.type.spec) + assert typestr not in typesToIncludes + typesToIncludes[typestr] = using.header + + aggregateTypeIncludes = set() + for su in tu.structsAndUnions: + typedeps = _ComputeTypeDeps(su.decl.type, True) + if isinstance(su, ipdl.ast.StructDecl): + for f in su.fields: + f.ipdltype.accept(typedeps) + elif isinstance(su, ipdl.ast.UnionDecl): + for c in su.components: + c.ipdltype.accept(typedeps) + + for typename in [t.fromtype.name for t in typedeps.usingTypedefs]: + if typename in typesToIncludes: + aggregateTypeIncludes.add(typesToIncludes[typename]) + + if len(aggregateTypeIncludes) != 0: + hf.addthing(Whitespace.NL) + hf.addthings([ Whitespace("// Headers for typedefs"), Whitespace.NL ]) + + for headername in sorted(iter(aggregateTypeIncludes)): + hf.addthing(CppDirective('include', '"' + headername + '"')) + + ipdl.ast.Visitor.visitTranslationUnit(self, tu) + if tu.filetype == 'header': + self.cppIncludeHeaders.append(_ipdlhHeaderName(tu)) + + hf.addthing(Whitespace.NL) + hf.addthings(_includeGuardEnd(hf)) + + cf = self.cppfile + cf.addthings(( + [ _DISCLAIMER, Whitespace.NL ] + + [ CppDirective('include','"'+h+'.h"') + for h in self.cppIncludeHeaders ] + + [ Whitespace.NL ] + )) + + if self.protocol: + # construct the namespace into which we'll stick all our defns + ns = Namespace(self.protocol.name) + cf.addthing(_putInNamespaces(ns, self.protocol.namespaces)) + ns.addstmts(([ Whitespace.NL] + + self.funcDefns + +[ Whitespace.NL ])) + + cf.addthings(self.structUnionDefns) + + + def visitBuiltinCxxInclude(self, inc): + self.hdrfile.addthing(CppDirective('include', '"'+ inc.file +'"')) + + def visitInclude(self, inc): + if inc.tu.filetype == 'header': + self.hdrfile.addthing(CppDirective( + 'include', '"'+ _ipdlhHeaderName(inc.tu) +'.h"')) + + def processStructOrUnionClass(self, su, which, forwarddecls, cls): + clsdecl, methoddefns = _splitClassDeclDefn(cls) + + self.hdrfile.addthings( + [ Whitespace.NL ] + + forwarddecls + + [ Whitespace(""" +//----------------------------------------------------------------------------- +// Declaration of the IPDL type |%s %s| +// +"""% (which, su.name)), + _putInNamespaces(clsdecl, su.namespaces), + ]) + + self.structUnionDefns.extend([ + Whitespace(""" +//----------------------------------------------------------------------------- +// Method definitions for the IPDL type |%s %s| +// +"""% (which, su.name)), + _putInNamespaces(methoddefns, su.namespaces), + ]) + + def visitStructDecl(self, sd): + return self.processStructOrUnionClass(sd, 'struct', + *_generateCxxStruct(sd)) + + def visitUnionDecl(self, ud): + return self.processStructOrUnionClass(ud, 'union', + *_generateCxxUnion(ud)) + + def visitProtocol(self, p): + self.cppIncludeHeaders.append(_protocolHeaderName(self.protocol, '')) + + # Forward declare our own actors. + self.hdrfile.addthings([ + Whitespace.NL, + _makeForwardDeclForActor(p.decl.type, 'Parent'), + _makeForwardDeclForActor(p.decl.type, 'Child') + ]) + + bridges = ProcessGraph.bridgesOf(p.decl.type) + for bridge in bridges: + ppt, pside = bridge.parent.ptype, _otherSide(bridge.parent.side) + cpt, cside = bridge.child.ptype, _otherSide(bridge.child.side) + self.hdrfile.addthings([ + Whitespace.NL, + _makeForwardDeclForActor(ppt, pside), + _makeForwardDeclForActor(cpt, cside) + ]) + self.cppIncludeHeaders.append(_protocolHeaderName(ppt._ast, pside)) + self.cppIncludeHeaders.append(_protocolHeaderName(cpt._ast, cside)) + + opens = ProcessGraph.opensOf(p.decl.type) + for o in opens: + optype, oside = o.opener.ptype, o.opener.side + self.hdrfile.addthings([ + Whitespace.NL, + _makeForwardDeclForActor(optype, oside) + ]) + self.cppIncludeHeaders.append(_protocolHeaderName(optype._ast, oside)) + + self.hdrfile.addthing(Whitespace(""" +//----------------------------------------------------------------------------- +// Code common to %sChild and %sParent +// +"""% (p.name, p.name))) + + # construct the namespace into which we'll stick all our decls + ns = Namespace(self.protocol.name) + self.hdrfile.addthing(_putInNamespaces(ns, p.namespaces)) + ns.addstmt(Whitespace.NL) + + # user-facing methods for connecting two process with a new channel + for bridge in bridges: + bdecl, bdefn = _splitFuncDeclDefn(self.genBridgeFunc(bridge)) + ns.addstmts([ bdecl, Whitespace.NL ]) + self.funcDefns.append(bdefn) + + # user-facing methods for opening a new channel across two + # existing endpoints + for o in opens: + odecl, odefn = _splitFuncDeclDefn(self.genOpenFunc(o)) + ns.addstmts([ odecl, Whitespace.NL ]) + self.funcDefns.append(odefn) + + edecl, edefn = _splitFuncDeclDefn(self.genEndpointFunc()) + ns.addstmts([ edecl, Whitespace.NL ]) + self.funcDefns.append(edefn) + + # state information + stateenum = TypeEnum('State') + # NB: __Dead is the first state on purpose, so that it has + # value '0' + stateenum.addId(_deadState().name) + stateenum.addId(_nullState().name) + stateenum.addId(_errorState().name) + stateenum.addId(_dyingState().name) + for ts in p.transitionStmts: + stateenum.addId(ts.state.decl.cxxname) + if len(p.transitionStmts): + startstate = p.transitionStmts[0].state.decl.cxxname + else: + startstate = _nullState().name + stateenum.addId(_startState().name, startstate) + + ns.addstmts([ StmtDecl(Decl(stateenum,'')), Whitespace.NL ]) + + # spit out message type enum and classes + msgenum = msgenums(self.protocol) + ns.addstmts([ StmtDecl(Decl(msgenum, '')), Whitespace.NL ]) + + tfDecl, tfDefn = _splitFuncDeclDefn(self.genTransitionFunc()) + ns.addstmts([ tfDecl, Whitespace.NL ]) + self.funcDefns.append(tfDefn) + + for md in p.messageDecls: + decls = [] + + mfDecl, mfDefn = _splitFuncDeclDefn( + _generateMessageConstructor(md.msgCtorFunc(), md.msgId(), + md.decl.type.nested, + md.decl.type.prio, + md.prettyMsgName(p.name+'::'), + md.decl.type.compress)) + decls.append(mfDecl) + self.funcDefns.append(mfDefn) + + if md.hasReply(): + rfDecl, rfDefn = _splitFuncDeclDefn( + _generateMessageConstructor( + md.replyCtorFunc(), md.replyId(), + md.decl.type.nested, + md.decl.type.prio, + md.prettyReplyName(p.name+'::'), + md.decl.type.compress)) + decls.append(rfDecl) + self.funcDefns.append(rfDefn) + + decls.append(Whitespace.NL) + ns.addstmts(decls) + + ns.addstmts([ Whitespace.NL, Whitespace.NL ]) + + + def genBridgeFunc(self, bridge): + p = self.protocol + parentHandleType = _cxxBareType(ActorType(bridge.parent.ptype), + _otherSide(bridge.parent.side), + fq=1) + parentvar = ExprVar('parentHandle') + + childHandleType = _cxxBareType(ActorType(bridge.child.ptype), + _otherSide(bridge.child.side), + fq=1) + childvar = ExprVar('childHandle') + + bridgefunc = MethodDefn(MethodDecl( + 'Bridge', + params=[ Decl(parentHandleType, parentvar.name), + Decl(childHandleType, childvar.name) ], + ret=Type.NSRESULT)) + bridgefunc.addstmt(StmtReturn(ExprCall( + ExprVar('mozilla::ipc::Bridge'), + args=[ _backstagePass(), + p.callGetChannel(parentvar), p.callOtherPid(parentvar), + p.callGetChannel(childvar), p.callOtherPid(childvar), + _protocolId(p.decl.type), + ExprVar(_messageStartName(p.decl.type) + 'Child') + ]))) + return bridgefunc + + + def genOpenFunc(self, o): + p = self.protocol + localside = o.opener.side + openertype = _cxxBareType(ActorType(o.opener.ptype), o.opener.side, + fq=1) + openervar = ExprVar('opener') + openfunc = MethodDefn(MethodDecl( + 'Open', + params=[ Decl(openertype, openervar.name) ], + ret=Type.BOOL)) + openfunc.addstmt(StmtReturn(ExprCall( + ExprVar('mozilla::ipc::Open'), + args=[ _backstagePass(), + p.callGetChannel(openervar), p.callOtherPid(openervar), + _sideToTransportMode(localside), + _protocolId(p.decl.type), + ExprVar(_messageStartName(p.decl.type) + 'Child') + ]))) + return openfunc + + + # Generate code for PFoo::CreateEndpoints. + def genEndpointFunc(self): + p = self.protocol.decl.type + tparent = _cxxBareType(ActorType(p), 'Parent', fq=1) + tchild = _cxxBareType(ActorType(p), 'Child', fq=1) + methodvar = ExprVar('CreateEndpoints') + rettype = Type.NSRESULT + parentpidvar = ExprVar('aParentDestPid') + childpidvar = ExprVar('aChildDestPid') + parentvar = ExprVar('aParent') + childvar = ExprVar('aChild') + + openfunc = MethodDefn(MethodDecl( + methodvar.name, + params=[ Decl(Type('base::ProcessId'), parentpidvar.name), + Decl(Type('base::ProcessId'), childpidvar.name), + Decl(Type('mozilla::ipc::Endpoint<' + tparent.name + '>', ptr=1), parentvar.name), + Decl(Type('mozilla::ipc::Endpoint<' + tchild.name + '>', ptr=1), childvar.name) ], + ret=rettype)) + openfunc.addstmt(StmtReturn(ExprCall( + ExprVar('mozilla::ipc::CreateEndpoints'), + args=[ _backstagePass(), + parentpidvar, childpidvar, + _protocolId(p), + ExprVar(_messageStartName(p) + 'Child'), + parentvar, childvar + ]))) + return openfunc + + + def genTransitionFunc(self): + ptype = self.protocol.decl.type + usesend, sendvar = set(), ExprVar('Send__') + userecv, recvvar = set(), ExprVar('Recv__') + + def sameTrigger(trigger, actionexpr): + if trigger is ipdl.ast.SEND or trigger is ipdl.ast.CALL: + usesend.add('yes') + return ExprBinary(sendvar, '==', actionexpr) + else: + userecv.add('yes') + return ExprBinary(recvvar, '==', + actionexpr) + + def stateEnum(s): + if s is ipdl.ast.State.DEAD: + return _deadState() + else: + return ExprVar(s.decl.cxxname) + + # bool Transition(Trigger trigger, State* next) + # The state we are transitioning from is stored in *next. + fromvar = ExprVar('from') + triggervar = ExprVar('trigger') + nextvar = ExprVar('next') + msgexpr = ExprSelect(triggervar, '.', 'mMessage') + actionexpr = ExprSelect(triggervar, '.', 'mAction') + + transitionfunc = FunctionDefn(FunctionDecl( + 'Transition', + params=[ Decl(Type('mozilla::ipc::Trigger'), triggervar.name), + Decl(Type('State', ptr=1), nextvar.name) ], + ret=Type.BOOL)) + + fromswitch = StmtSwitch(fromvar) + + for ts in self.protocol.transitionStmts: + msgswitch = StmtSwitch(msgexpr) + + msgToTransitions = { } + + for t in ts.transitions: + msgid = t.msg._md.msgId() + + ifsametrigger = StmtIf(sameTrigger(t.trigger, actionexpr)) + # FIXME multi-out states + for nextstate in t.toStates: break + ifsametrigger.addifstmts([ + StmtExpr(ExprAssn(ExprDeref(nextvar), + stateEnum(nextstate))), + StmtReturn(ExprLiteral.TRUE) + ]) + + transitions = msgToTransitions.get(msgid, [ ]) + transitions.append(ifsametrigger) + msgToTransitions[msgid] = transitions + + for msgid, transitions in msgToTransitions.iteritems(): + block = Block() + block.addstmts(transitions +[ StmtBreak() ]) + msgswitch.addcase(CaseLabel(msgid), block) + + msgblock = Block() + msgblock.addstmts([ + msgswitch, + StmtBreak() + ]) + fromswitch.addcase(CaseLabel(ts.state.decl.cxxname), msgblock) + + # special cases for Null and Error + nullerrorblock = Block() + if ptype.hasDelete: + ifdelete = StmtIf(ExprBinary(_deleteId(), '==', msgexpr)) + if ptype.hasReentrantDelete: + nextState = _dyingState() + else: + nextState = _deadState() + ifdelete.addifstmts([ + StmtExpr(ExprAssn(ExprDeref(nextvar), nextState)), + StmtReturn(ExprLiteral.TRUE) ]) + nullerrorblock.addstmt(ifdelete) + nullerrorblock.addstmt( + StmtReturn(ExprBinary(_nullState(), '==', fromvar))) + fromswitch.addfallthrough(CaseLabel(_nullState().name)) + fromswitch.addcase(CaseLabel(_errorState().name), nullerrorblock) + + # special case for Dead + deadblock = Block() + deadblock.addstmts([ + _logicError('__delete__()d actor'), + StmtReturn(ExprLiteral.FALSE) ]) + fromswitch.addcase(CaseLabel(_deadState().name), deadblock) + + # special case for Dying + dyingblock = Block() + if ptype.hasReentrantDelete: + ifdelete = StmtIf(ExprBinary(_deleteReplyId(), '==', msgexpr)) + ifdelete.addifstmt( + StmtExpr(ExprAssn(ExprDeref(nextvar), _deadState()))) + dyingblock.addstmt(ifdelete) + dyingblock.addstmt( + StmtReturn(ExprLiteral.TRUE)) + else: + dyingblock.addstmts([ + _logicError('__delete__()d (and unexpectedly dying) actor'), + StmtReturn(ExprLiteral.FALSE) ]) + fromswitch.addcase(CaseLabel(_dyingState().name), dyingblock) + + unreachedblock = Block() + unreachedblock.addstmts([ + _logicError('corrupted actor state'), + StmtReturn(ExprLiteral.FALSE) ]) + fromswitch.addcase(DefaultLabel(), unreachedblock) + + if usesend: + transitionfunc.addstmt( + StmtDecl(Decl(Type('int32_t', const=1), sendvar.name), + init=ExprVar('mozilla::ipc::Trigger::Send'))) + if userecv: + transitionfunc.addstmt( + StmtDecl(Decl(Type('int32_t', const=1), recvvar.name), + init=ExprVar('mozilla::ipc::Trigger::Recv'))) + if usesend or userecv: + transitionfunc.addstmt(Whitespace.NL) + + transitionfunc.addstmt(StmtDecl(Decl(Type('State'), fromvar.name), + init=ExprDeref(nextvar))) + transitionfunc.addstmt(fromswitch) + # all --> Error transitions break to here. But only insert this + # block if there is any possibility of such transitions. + if self.protocol.transitionStmts: + transitionfunc.addstmts([ + StmtExpr(ExprAssn(ExprDeref(nextvar), _errorState())), + StmtReturn(ExprLiteral.FALSE), + ]) + + return transitionfunc + +##-------------------------------------------------- + +def _generateMessageConstructor(clsname, msgid, nested, prio, prettyName, compress): + routingId = ExprVar('routingId') + + func = FunctionDefn(FunctionDecl( + clsname, + params=[ Decl(Type('int32_t'), routingId.name) ], + ret=Type('IPC::Message', ptr=1))) + + if compress == 'compress': + compression = ExprVar('IPC::Message::COMPRESSION_ENABLED') + elif compress: + assert compress == 'compressall' + compression = ExprVar('IPC::Message::COMPRESSION_ALL') + else: + compression = ExprVar('IPC::Message::COMPRESSION_NONE') + + if nested == ipdl.ast.NOT_NESTED: + nestedEnum = 'IPC::Message::NOT_NESTED' + elif nested == ipdl.ast.INSIDE_SYNC_NESTED: + nestedEnum = 'IPC::Message::NESTED_INSIDE_SYNC' + else: + assert nested == ipdl.ast.INSIDE_CPOW_NESTED + nestedEnum = 'IPC::Message::NESTED_INSIDE_CPOW' + + if prio == ipdl.ast.NORMAL_PRIORITY: + prioEnum = 'IPC::Message::NORMAL_PRIORITY' + else: + assert prio == ipdl.ast.HIGH_PRIORITY + prioEnum = 'IPC::Message::HIGH_PRIORITY' + + func.addstmt( + StmtReturn(ExprNew(Type('IPC::Message'), + args=[ routingId, + ExprVar(msgid), + ExprVar(nestedEnum), + ExprVar(prioEnum), + compression, + ExprLiteral.String(prettyName) ]))) + + return func + +##-------------------------------------------------- + +class _ComputeTypeDeps(TypeVisitor): + '''Pass that gathers the C++ types that a particular IPDL type +(recursively) depends on. There are two kinds of dependencies: (i) +types that need forward declaration; (ii) types that need a |using| +stmt. Some types generate both kinds.''' + + def __init__(self, fortype, unqualifiedTypedefs=False): + ipdl.type.TypeVisitor.__init__(self) + self.usingTypedefs = [ ] + self.forwardDeclStmts = [ ] + self.fortype = fortype + self.unqualifiedTypedefs = unqualifiedTypedefs + + def maybeTypedef(self, fqname, name): + if fqname != name or self.unqualifiedTypedefs: + self.usingTypedefs.append(Typedef(Type(fqname), name)) + + def visitBuiltinCxxType(self, t): + if t in self.visited: return + self.visited.add(t) + self.maybeTypedef(t.fullname(), t.name()) + + def visitImportedCxxType(self, t): + if t in self.visited: return + self.visited.add(t) + self.maybeTypedef(t.fullname(), t.name()) + + def visitActorType(self, t): + if t in self.visited: return + self.visited.add(t) + + fqname, name = t.fullname(), t.name() + + self.maybeTypedef(_actorName(fqname, 'Parent'), + _actorName(name, 'Parent')) + self.maybeTypedef(_actorName(fqname, 'Child'), + _actorName(name, 'Child')) + + self.forwardDeclStmts.extend([ + _makeForwardDeclForActor(t.protocol, 'parent'), Whitespace.NL, + _makeForwardDeclForActor(t.protocol, 'child'), Whitespace.NL + ]) + + def visitStructOrUnionType(self, su, defaultVisit): + if su in self.visited or su == self.fortype: return + self.visited.add(su) + self.maybeTypedef(su.fullname(), su.name()) + + if su.mutuallyRecursiveWith(self.fortype): + self.forwardDeclStmts.append(_makeForwardDecl(su)) + + return defaultVisit(self, su) + + def visitStructType(self, t): + return self.visitStructOrUnionType(t, TypeVisitor.visitStructType) + + def visitUnionType(self, t): + return self.visitStructOrUnionType(t, TypeVisitor.visitUnionType) + + def visitArrayType(self, t): + return TypeVisitor.visitArrayType(self, t) + + def visitShmemType(self, s): + if s in self.visited: return + self.visited.add(s) + self.maybeTypedef('mozilla::ipc::Shmem', 'Shmem') + + def visitFDType(self, s): + if s in self.visited: return + self.visited.add(s) + self.maybeTypedef('mozilla::ipc::FileDescriptor', 'FileDescriptor') + + def visitVoidType(self, v): assert 0 + def visitMessageType(self, v): assert 0 + def visitProtocolType(self, v): assert 0 + def visitStateType(self, v): assert 0 + + +def _generateCxxStruct(sd): + ''' ''' + # compute all the typedefs and forward decls we need to make + gettypedeps = _ComputeTypeDeps(sd.decl.type) + for f in sd.fields: + f.ipdltype.accept(gettypedeps) + + usingTypedefs = gettypedeps.usingTypedefs + forwarddeclstmts = gettypedeps.forwardDeclStmts + + struct = Class(sd.name, final=1) + struct.addstmts([ Label.PRIVATE ] + + usingTypedefs + + [ Whitespace.NL, Label.PUBLIC ]) + + constreftype = Type(sd.name, const=1, ref=1) + initvar = ExprVar('Init') + callinit = ExprCall(initvar) + assignvar = ExprVar('Assign') + + def fieldsAsParamList(): + return [ Decl(f.inType(), f.argVar().name) for f in sd.fields ] + + def assignFromOther(oexpr): + return ExprCall(assignvar, + args=[ f.initExpr(oexpr) for f in sd.fields ]) + + # If this is an empty struct (no fields), then the default ctor + # and "create-with-fields" ctors are equivalent. So don't bother + # with the default ctor. + if len(sd.fields): + # Struct() + defctor = ConstructorDefn(ConstructorDecl(sd.name)) + defctor.addstmt(StmtExpr(callinit)) + defctor.memberinits = [] + for f in sd.fields: + # Only generate default values for primitives. + if not (f.ipdltype.isCxx() and f.ipdltype.isAtom()): + continue + defctor.memberinits.append(ExprMemberInit(f.memberVar())) + struct.addstmts([ defctor, Whitespace.NL ]) + + # Struct(const field1& _f1, ...) + valctor = ConstructorDefn(ConstructorDecl(sd.name, + params=fieldsAsParamList(), + force_inline=1)) + valctor.addstmts([ + StmtExpr(callinit), + StmtExpr(ExprCall(assignvar, + args=[ f.argVar() for f in sd.fields ])) + ]) + struct.addstmts([ valctor, Whitespace.NL ]) + + # Struct(const Struct& _o) + ovar = ExprVar('_o') + copyctor = ConstructorDefn(ConstructorDecl( + sd.name, + params=[ Decl(constreftype, ovar.name) ], + force_inline=1)) + copyctor.addstmts([ + StmtExpr(callinit), + StmtExpr(assignFromOther(ovar)) + ]) + struct.addstmts([ copyctor, Whitespace.NL ]) + + # ~Struct() + dtor = DestructorDefn(DestructorDecl(sd.name)) + for f in sd.fields: + dtor.addstmts(f.destructStmts()) + struct.addstmts([ dtor, Whitespace.NL ]) + + # Struct& operator=(const Struct& _o) + opeq = MethodDefn(MethodDecl( + 'operator=', + params=[ Decl(constreftype, ovar.name) ], + force_inline=1)) + opeq.addstmt(StmtExpr(assignFromOther(ovar))) + struct.addstmts([ opeq, Whitespace.NL ]) + + # bool operator==(const Struct& _o) + opeqeq = MethodDefn(MethodDecl( + 'operator==', + params=[ Decl(constreftype, ovar.name) ], + ret=Type.BOOL, + const=1)) + for f in sd.fields: + ifneq = StmtIf(ExprNot( + ExprBinary(ExprCall(f.getMethod()), '==', + ExprCall(f.getMethod(ovar))))) + ifneq.addifstmt(StmtReturn.FALSE) + opeqeq.addstmt(ifneq) + opeqeq.addstmt(StmtReturn.TRUE) + struct.addstmts([ opeqeq, Whitespace.NL ]) + + # field1& f1() + # const field1& f1() const + for f in sd.fields: + get = MethodDefn(MethodDecl(f.getMethod().name, + params=[ ], + ret=f.refType(), + force_inline=1)) + get.addstmt(StmtReturn(f.refExpr())) + + getconstdecl = deepcopy(get.decl) + getconstdecl.ret = f.constRefType() + getconstdecl.const = 1 + getconst = MethodDefn(getconstdecl) + getconst.addstmt(StmtReturn(f.constRefExpr())) + + struct.addstmts([ get, getconst, Whitespace.NL ]) + + # private: + struct.addstmt(Label.PRIVATE) + + # Init() + init = MethodDefn(MethodDecl(initvar.name)) + for f in sd.fields: + init.addstmts(f.initStmts()) + struct.addstmts([ init, Whitespace.NL ]) + + # Assign(const field1& _f1, ...) + assign = MethodDefn(MethodDecl(assignvar.name, + params=fieldsAsParamList())) + assign.addstmts([ StmtExpr(ExprAssn(f.refExpr(), f.argVar())) + for f in sd.fields ]) + struct.addstmts([ assign, Whitespace.NL ]) + + # members + struct.addstmts([ StmtDecl(Decl(f.internalType(), f.memberVar().name)) + for f in sd.fields ]) + + return forwarddeclstmts, struct + +##-------------------------------------------------- + +def _generateCxxUnion(ud): + # This Union class basically consists of a type (enum) and a + # union for storage. The union can contain POD and non-POD + # types. Each type needs a copy ctor, assignment operator, + # and dtor. + # + # Rather than templating this class and only providing + # specializations for the types we support, which is slightly + # "unsafe" in that C++ code can add additional specializations + # without the IPDL compiler's knowledge, we instead explicitly + # implement non-templated methods for each supported type. + # + # The one complication that arises is that C++, for arcane + # reasons, does not allow the placement destructor of a + # builtin type, like int, to be directly invoked. So we need + # to hack around this by internally typedef'ing all + # constituent types. Sigh. + # + # So, for each type, this "Union" class needs: + # (private) + # - entry in the type enum + # - entry in the storage union + # - [type]ptr() method to get a type* from the underlying union + # - same as above to get a const type* + # - typedef to hack around placement delete limitations + # (public) + # - placement delete case for dtor + # - copy ctor + # - case in generic copy ctor + # - operator= impl + # - case in generic operator= + # - operator [type&] + # - operator [const type&] const + # - [type&] get_[type]() + # - [const type&] get_[type]() const + # + cls = Class(ud.name, final=1) + # const Union&, i.e., Union type with inparam semantics + inClsType = Type(ud.name, const=1, ref=1) + refClsType = Type(ud.name, ref=1) + typetype = Type('Type') + valuetype = Type('Value') + mtypevar = ExprVar('mType') + mvaluevar = ExprVar('mValue') + maybedtorvar = ExprVar('MaybeDestroy') + assertsanityvar = ExprVar('AssertSanity') + tnonevar = ExprVar('T__None') + tlastvar = ExprVar('T__Last') + + def callAssertSanity(uvar=None, expectTypeVar=None): + func = assertsanityvar + args = [ ] + if uvar is not None: + func = ExprSelect(uvar, '.', assertsanityvar.name) + if expectTypeVar is not None: + args.append(expectTypeVar) + return ExprCall(func, args=args) + + def callMaybeDestroy(newTypeVar): + return ExprCall(maybedtorvar, args=[ newTypeVar ]) + + def maybeReconstruct(memb, newTypeVar): + ifdied = StmtIf(callMaybeDestroy(newTypeVar)) + ifdied.addifstmt(StmtExpr(memb.callCtor())) + return ifdied + + # compute all the typedefs and forward decls we need to make + gettypedeps = _ComputeTypeDeps(ud.decl.type) + for c in ud.components: + c.ipdltype.accept(gettypedeps) + + usingTypedefs = gettypedeps.usingTypedefs + forwarddeclstmts = gettypedeps.forwardDeclStmts + + # the |Type| enum, used to switch on the discunion's real type + cls.addstmt(Label.PUBLIC) + typeenum = TypeEnum(typetype.name) + typeenum.addId(tnonevar.name, 0) + firstid = ud.components[0].enum() + typeenum.addId(firstid, 1) + for c in ud.components[1:]: + typeenum.addId(c.enum()) + typeenum.addId(tlastvar.name, ud.components[-1].enum()) + cls.addstmts([ StmtDecl(Decl(typeenum,'')), + Whitespace.NL ]) + + cls.addstmt(Label.PRIVATE) + cls.addstmts( + usingTypedefs + # hacky typedef's that allow placement dtors of builtins + + [ Typedef(c.internalType(), c.typedef()) for c in ud.components ]) + cls.addstmt(Whitespace.NL) + + # the C++ union the discunion use for storage + valueunion = TypeUnion(valuetype.name) + for c in ud.components: + valueunion.addComponent(c.unionType(), c.name) + cls.addstmts([ StmtDecl(Decl(valueunion,'')), + Whitespace.NL ]) + + # for each constituent type T, add private accessors that + # return a pointer to the Value union storage casted to |T*| + # and |const T*| + for c in ud.components: + getptr = MethodDefn(MethodDecl( + c.getPtrName(), params=[ ], ret=c.ptrToInternalType(), + force_inline=1)) + getptr.addstmt(StmtReturn(c.ptrToSelfExpr())) + + getptrconst = MethodDefn(MethodDecl( + c.getConstPtrName(), params=[ ], ret=c.constPtrToType(), + const=1, force_inline=1)) + getptrconst.addstmt(StmtReturn(c.constptrToSelfExpr())) + + cls.addstmts([ getptr, getptrconst ]) + cls.addstmt(Whitespace.NL) + + # add a helper method that invokes the placement dtor on the + # current underlying value, only if |aNewType| is different + # than the current type, and returns true if the underlying + # value needs to be re-constructed + newtypevar = ExprVar('aNewType') + maybedtor = MethodDefn(MethodDecl( + maybedtorvar.name, + params=[ Decl(typetype, newtypevar.name) ], + ret=Type.BOOL)) + # wasn't /actually/ dtor'd, but it needs to be re-constructed + ifnone = StmtIf(ExprBinary(mtypevar, '==', tnonevar)) + ifnone.addifstmt(StmtReturn.TRUE) + # same type, nothing to see here + ifnochange = StmtIf(ExprBinary(mtypevar, '==', newtypevar)) + ifnochange.addifstmt(StmtReturn.FALSE) + # need to destroy. switch on underlying type + dtorswitch = StmtSwitch(mtypevar) + for c in ud.components: + dtorswitch.addcase( + CaseLabel(c.enum()), + StmtBlock([ StmtExpr(c.callDtor()), + StmtBreak() ])) + dtorswitch.addcase( + DefaultLabel(), + StmtBlock([ _logicError("not reached"), StmtBreak() ])) + maybedtor.addstmts([ + ifnone, + ifnochange, + dtorswitch, + StmtReturn.TRUE + ]) + cls.addstmts([ maybedtor, Whitespace.NL ]) + + # add helper methods that ensure the discunion has a + # valid type + sanity = MethodDefn(MethodDecl( + assertsanityvar.name, ret=Type.VOID, const=1, force_inline=1)) + sanity.addstmts([ + _abortIfFalse(ExprBinary(tnonevar, '<=', mtypevar), + 'invalid type tag'), + _abortIfFalse(ExprBinary(mtypevar, '<=', tlastvar), + 'invalid type tag') ]) + cls.addstmt(sanity) + + atypevar = ExprVar('aType') + sanity2 = MethodDefn( + MethodDecl(assertsanityvar.name, + params=[ Decl(typetype, atypevar.name) ], + ret=Type.VOID, + const=1, force_inline=1)) + sanity2.addstmts([ + StmtExpr(ExprCall(assertsanityvar)), + _abortIfFalse(ExprBinary(mtypevar, '==', atypevar), + 'unexpected type tag') ]) + cls.addstmts([ sanity2, Whitespace.NL ]) + + ## ---- begin public methods ----- + + # Union() default ctor + cls.addstmts([ + Label.PUBLIC, + ConstructorDefn( + ConstructorDecl(ud.name, force_inline=1), + memberinits=[ ExprMemberInit(mtypevar, [ tnonevar ]) ]), + Whitespace.NL + ]) + + # Union(const T&) copy ctors + othervar = ExprVar('aOther') + for c in ud.components: + copyctor = ConstructorDefn(ConstructorDecl( + ud.name, params=[ Decl(c.inType(), othervar.name) ])) + copyctor.addstmts([ + StmtExpr(c.callCtor(othervar)), + StmtExpr(ExprAssn(mtypevar, c.enumvar())) ]) + cls.addstmts([ copyctor, Whitespace.NL ]) + + # Union(const Union&) copy ctor + copyctor = ConstructorDefn(ConstructorDecl( + ud.name, params=[ Decl(inClsType, othervar.name) ])) + othertype = ud.callType(othervar) + copyswitch = StmtSwitch(othertype) + for c in ud.components: + copyswitch.addcase( + CaseLabel(c.enum()), + StmtBlock([ + StmtExpr(c.callCtor( + ExprCall(ExprSelect(othervar, + '.', c.getConstTypeName())))), + StmtBreak() + ])) + copyswitch.addcase(CaseLabel(tnonevar.name), + StmtBlock([ StmtBreak() ])) + copyswitch.addcase( + DefaultLabel(), + StmtBlock([ _logicError('unreached'), StmtReturn() ])) + copyctor.addstmts([ + StmtExpr(callAssertSanity(uvar=othervar)), + copyswitch, + StmtExpr(ExprAssn(mtypevar, othertype)) + ]) + cls.addstmts([ copyctor, Whitespace.NL ]) + + # ~Union() + dtor = DestructorDefn(DestructorDecl(ud.name)) + # The void cast prevents Coverity from complaining about missing return + # value checks. + dtor.addstmt(StmtExpr(ExprCast(callMaybeDestroy(tnonevar), Type.VOID, + static=1))) + cls.addstmts([ dtor, Whitespace.NL ]) + + # type() + typemeth = MethodDefn(MethodDecl('type', ret=typetype, + const=1, force_inline=1)) + typemeth.addstmt(StmtReturn(mtypevar)) + cls.addstmts([ typemeth, Whitespace.NL ]) + + # Union& operator=(const T&) methods + rhsvar = ExprVar('aRhs') + for c in ud.components: + opeq = MethodDefn(MethodDecl( + 'operator=', + params=[ Decl(c.inType(), rhsvar.name) ], + ret=refClsType)) + opeq.addstmts([ + # might need to placement-delete old value first + maybeReconstruct(c, c.enumvar()), + StmtExpr(c.callOperatorEq(rhsvar)), + StmtExpr(ExprAssn(mtypevar, c.enumvar())), + StmtReturn(ExprDeref(ExprVar.THIS)) + ]) + cls.addstmts([ opeq, Whitespace.NL ]) + + # Union& operator=(const Union&) + opeq = MethodDefn(MethodDecl( + 'operator=', + params=[ Decl(inClsType, rhsvar.name) ], + ret=refClsType)) + rhstypevar = ExprVar('t') + opeqswitch = StmtSwitch(rhstypevar) + for c in ud.components: + case = StmtBlock() + case.addstmts([ + maybeReconstruct(c, rhstypevar), + StmtExpr(c.callOperatorEq( + ExprCall(ExprSelect(rhsvar, '.', c.getConstTypeName())))), + StmtBreak() + ]) + opeqswitch.addcase(CaseLabel(c.enum()), case) + opeqswitch.addcase( + CaseLabel(tnonevar.name), + # The void cast prevents Coverity from complaining about missing return + # value checks. + StmtBlock([ StmtExpr(ExprCast(callMaybeDestroy(rhstypevar), Type.VOID, + static=1)), + StmtBreak() ]) + ) + opeqswitch.addcase( + DefaultLabel(), + StmtBlock([ _logicError('unreached'), StmtBreak() ])) + opeq.addstmts([ + StmtExpr(callAssertSanity(uvar=rhsvar)), + StmtDecl(Decl(typetype, rhstypevar.name), init=ud.callType(rhsvar)), + opeqswitch, + StmtExpr(ExprAssn(mtypevar, rhstypevar)), + StmtReturn(ExprDeref(ExprVar.THIS)) + ]) + cls.addstmts([ opeq, Whitespace.NL ]) + + # bool operator==(const T&) + for c in ud.components: + opeqeq = MethodDefn(MethodDecl( + 'operator==', + params=[ Decl(c.inType(), rhsvar.name) ], + ret=Type.BOOL, + const=1)) + opeqeq.addstmt(StmtReturn(ExprBinary( + ExprCall(ExprVar(c.getTypeName())), '==', rhsvar))) + cls.addstmts([ opeqeq, Whitespace.NL ]) + + # bool operator==(const Union&) + opeqeq = MethodDefn(MethodDecl( + 'operator==', + params=[ Decl(inClsType, rhsvar.name) ], + ret=Type.BOOL, + const=1)) + iftypesmismatch = StmtIf(ExprBinary(ud.callType(), '!=', + ud.callType(rhsvar))) + iftypesmismatch.addifstmt(StmtReturn.FALSE) + opeqeq.addstmts([ iftypesmismatch, Whitespace.NL ]) + + opeqeqswitch = StmtSwitch(ud.callType()) + for c in ud.components: + case = StmtBlock() + case.addstmt(StmtReturn(ExprBinary( + ExprCall(ExprVar(c.getTypeName())), '==', + ExprCall(ExprSelect(rhsvar, '.', c.getTypeName()))))) + opeqeqswitch.addcase(CaseLabel(c.enum()), case) + opeqeqswitch.addcase( + DefaultLabel(), + StmtBlock([ _logicError('unreached'), + StmtReturn.FALSE ])) + opeqeq.addstmt(opeqeqswitch) + + cls.addstmts([ opeqeq, Whitespace.NL ]) + + # accessors for each type: operator T&, operator const T&, + # T& get(), const T& get() + for c in ud.components: + getValueVar = ExprVar(c.getTypeName()) + getConstValueVar = ExprVar(c.getConstTypeName()) + + getvalue = MethodDefn(MethodDecl(getValueVar.name, + ret=c.refType(), + force_inline=1)) + getvalue.addstmts([ + StmtExpr(callAssertSanity(expectTypeVar=c.enumvar())), + StmtReturn(ExprDeref(c.callGetPtr())) + ]) + + getconstvalue = MethodDefn(MethodDecl( + getConstValueVar.name, ret=c.constRefType(), + const=1, force_inline=1)) + getconstvalue.addstmts([ + StmtExpr(callAssertSanity(expectTypeVar=c.enumvar())), + StmtReturn(c.getConstValue()) + ]) + + readvalue = MethodDefn(MethodDecl( + 'get', ret=Type.VOID, const=1, + params=[Decl(c.ptrToType(), 'aOutValue')])) + readvalue.addstmts([ + StmtExpr(ExprAssn(ExprDeref(ExprVar('aOutValue')), + ExprCall(getConstValueVar))) + ]) + + optype = MethodDefn(MethodDecl('', typeop=c.refType(), force_inline=1)) + optype.addstmt(StmtReturn(ExprCall(getValueVar))) + opconsttype = MethodDefn(MethodDecl( + '', const=1, typeop=c.constRefType(), force_inline=1)) + opconsttype.addstmt(StmtReturn(ExprCall(getConstValueVar))) + + cls.addstmts([ getvalue, getconstvalue, readvalue, + optype, opconsttype, + Whitespace.NL ]) + + # private vars + cls.addstmts([ + Label.PRIVATE, + StmtDecl(Decl(valuetype, mvaluevar.name)), + StmtDecl(Decl(typetype, mtypevar.name)) + ]) + + return forwarddeclstmts, cls + +##----------------------------------------------------------------------------- + +class _FindFriends(ipdl.ast.Visitor): + def __init__(self): + self.mytype = None # ProtocolType + self.vtype = None # ProtocolType + self.friends = set() # set<ProtocolType> + + def findFriends(self, ptype): + self.mytype = ptype + for toplvl in ptype.toplevels(): + self.walkDownTheProtocolTree(toplvl); + return self.friends + + # TODO could make this into a _iterProtocolTreeHelper ... + def walkDownTheProtocolTree(self, ptype): + if ptype != self.mytype: + # don't want to |friend| ourself! + self.visit(ptype) + for mtype in ptype.manages: + if mtype is not ptype: + self.walkDownTheProtocolTree(mtype) + + def visit(self, ptype): + # |vtype| is the type currently being visited + savedptype = self.vtype + self.vtype = ptype + ptype._ast.accept(self) + self.vtype = savedptype + + def visitMessageDecl(self, md): + for it in self.iterActorParams(md): + if it.protocol == self.mytype: + self.friends.add(self.vtype) + + def iterActorParams(self, md): + for param in md.inParams: + for actor in ipdl.type.iteractortypes(param.type): + yield actor + for ret in md.outParams: + for actor in ipdl.type.iteractortypes(ret.type): + yield actor + + +class _GenerateProtocolActorCode(ipdl.ast.Visitor): + def __init__(self, myside): + self.side = myside # "parent" or "child" + self.prettyside = myside.title() + self.clsname = None + self.protocol = None + self.hdrfile = None + self.cppfile = None + self.ns = None + self.cls = None + self.includedActorTypedefs = [ ] + self.protocolCxxIncludes = [ ] + self.actorForwardDecls = [ ] + self.usingDecls = [ ] + self.externalIncludes = set() + self.nonForwardDeclaredHeaders = set() + + def lower(self, tu, clsname, cxxHeaderFile, cxxFile): + self.clsname = clsname + self.hdrfile = cxxHeaderFile + self.cppfile = cxxFile + tu.accept(self) + + def standardTypedefs(self): + return [ + Typedef(Type('mozilla::ipc::IProtocol'), 'ProtocolBase'), + Typedef(Type('IPC::Message'), 'Message'), + Typedef(Type(self.protocol.channelName()), 'Channel'), + Typedef(Type('mozilla::ipc::IProtocol'), 'ChannelListener'), + Typedef(Type('base::ProcessHandle'), 'ProcessHandle'), + Typedef(Type('mozilla::ipc::MessageChannel'), 'MessageChannel'), + Typedef(Type('mozilla::ipc::SharedMemory'), 'SharedMemory'), + Typedef(Type('mozilla::ipc::Trigger'), 'Trigger'), + ] + + + def visitTranslationUnit(self, tu): + self.protocol = tu.protocol + + hf = self.hdrfile + cf = self.cppfile + + # make the C++ header + hf.addthings( + [ _DISCLAIMER ] + + _includeGuardStart(hf) + +[ + Whitespace.NL, + CppDirective( + 'include', + '"'+ _protocolHeaderName(tu.protocol) +'.h"') + ]) + + for inc in tu.includes: + inc.accept(self) + for inc in tu.cxxIncludes: + inc.accept(self) + + for using in tu.using: + using.accept(self) + + # this generates the actor's full impl in self.cls + tu.protocol.accept(self) + + clsdecl, clsdefn = _splitClassDeclDefn(self.cls) + + # XXX damn C++ ... return types in the method defn aren't in + # class scope + for stmt in clsdefn.stmts: + if isinstance(stmt, MethodDefn): + if stmt.decl.ret and stmt.decl.ret.name == 'Result': + stmt.decl.ret.name = clsdecl.name +'::'+ stmt.decl.ret.name + + def setToIncludes(s): + return [ CppDirective('include', '"%s"' % i) + for i in sorted(iter(s)) ] + + def makeNamespace(p, file): + if 0 == len(p.namespaces): + return file + ns = Namespace(p.namespaces[-1].name) + outerns = _putInNamespaces(ns, p.namespaces[:-1]) + file.addthing(outerns) + return ns + + if len(self.nonForwardDeclaredHeaders) != 0: + self.hdrfile.addthings( + [ Whitespace('// Headers for things that cannot be forward declared'), + Whitespace.NL ] + + setToIncludes(self.nonForwardDeclaredHeaders) + + [ Whitespace.NL ] + ) + self.hdrfile.addthings(self.actorForwardDecls) + self.hdrfile.addthings(self.usingDecls) + + hdrns = makeNamespace(self.protocol, self.hdrfile) + hdrns.addstmts([ + Whitespace.NL, + Whitespace.NL, + clsdecl, + Whitespace.NL, + Whitespace.NL + ]) + + self.hdrfile.addthings( + ([ + Whitespace.NL, + CppDirective('if', '0') ]) + + _GenerateSkeletonImpl( + _actorName(self.protocol.name, self.side)[1:], + self.protocol.namespaces).fromclass(self.cls) + +([ + CppDirective('endif', '// if 0'), + Whitespace.NL ]) + + _includeGuardEnd(hf)) + + # make the .cpp file + cf.addthings([ + _DISCLAIMER, + Whitespace.NL, + CppDirective( + 'include', + '"'+ _protocolHeaderName(self.protocol, self.side) +'.h"') ] + + setToIncludes(self.externalIncludes)) + + if self.protocol.decl.type.isToplevel(): + cf.addthings([ + CppDirective('ifdef', 'MOZ_CRASHREPORTER'), + CppDirective(' include', '"nsXULAppAPI.h"'), + CppDirective('endif') + ]) + + cppheaders = [CppDirective('include', '"%s"' % filename) + for filename in ipdl.builtin.CppIncludes] + + cf.addthings(( + [ Whitespace.NL ] + + [ CppDirective( + 'include', + '"%s.h"' % (inc)) for inc in self.protocolCxxIncludes ] + + [ Whitespace.NL ] + + cppheaders + + [ Whitespace.NL ])) + + cppns = makeNamespace(self.protocol, cf) + cppns.addstmts([ + Whitespace.NL, + Whitespace.NL, + clsdefn, + Whitespace.NL, + Whitespace.NL + ]) + + def visitUsingStmt(self, using): + if using.header is None: + return + + if using.canBeForwardDeclared(): + spec = using.type.spec + + self.usingDecls.extend([ + _makeForwardDeclForQClass(spec.baseid, spec.quals, + cls=using.isClass(), + struct=using.isStruct()), + Whitespace.NL + ]) + self.externalIncludes.add(using.header) + else: + self.nonForwardDeclaredHeaders.add(using.header) + + def visitCxxInclude(self, inc): + self.nonForwardDeclaredHeaders.add(inc.file) + + def visitInclude(self, inc): + ip = inc.tu.protocol + if not ip: + return + + self.actorForwardDecls.extend([ + _makeForwardDeclForActor(ip.decl.type, self.side), + _makeForwardDeclForActor(ip.decl.type, _otherSide(self.side)), + Whitespace.NL + ]) + self.protocolCxxIncludes.append(_protocolHeaderName(ip, self.side)) + + if ip.decl.fullname is not None: + self.includedActorTypedefs.append(Typedef( + Type(_actorName(ip.decl.fullname, self.side.title())), + _actorName(ip.decl.shortname, self.side.title()))) + + self.includedActorTypedefs.append(Typedef( + Type(_actorName(ip.decl.fullname, _otherSide(self.side).title())), + _actorName(ip.decl.shortname, _otherSide(self.side).title()))) + + + def visitProtocol(self, p): + self.hdrfile.addthings([ + CppDirective('ifdef', 'DEBUG'), + CppDirective('include', '"prenv.h"'), + CppDirective('endif', '// DEBUG') + ]) + + self.protocol = p + ptype = p.decl.type + toplevel = p.decl.type.toplevel() + + # FIXME: all actors impl Iface for now + if ptype.isManager() or 1: + self.hdrfile.addthing(CppDirective('include', '"base/id_map.h"')) + + self.hdrfile.addthings([ + CppDirective('include', '"'+ p.channelHeaderFile() +'"'), + Whitespace.NL ]) + + inherits = [] + if ptype.isToplevel(): + inherits.append(Inherit(p.openedProtocolInterfaceType(), + viz='public')) + else: + inherits.append(Inherit(p.managerInterfaceType(), viz='public')) + + if ptype.isToplevel() and self.side is 'parent': + self.hdrfile.addthings([ + _makeForwardDeclForQClass('nsIFile', []), + Whitespace.NL + ]) + + self.cls = Class( + self.clsname, + inherits=inherits, + abstract=True) + + bridgeActorsCreated = ProcessGraph.bridgeEndpointsOf(ptype, self.side) + opensActorsCreated = ProcessGraph.opensEndpointsOf(ptype, self.side) + channelOpenedActors = OrderedDict.fromkeys(bridgeActorsCreated + opensActorsCreated, None) + + friends = _FindFriends().findFriends(ptype) + if ptype.isManaged(): + friends.update(ptype.managers) + + # |friend| managed actors so that they can call our Dealloc*() + friends.update(ptype.manages) + + # don't friend ourself if we're a self-managed protocol + friends.discard(ptype) + + for friend in friends: + self.actorForwardDecls.extend([ + _makeForwardDeclForActor(friend, self.prettyside), + Whitespace.NL + ]) + self.cls.addstmts([ + FriendClassDecl(_actorName(friend.fullname(), + self.prettyside)), + Whitespace.NL ]) + + for actor in channelOpenedActors: + self.hdrfile.addthings([ + Whitespace.NL, + _makeForwardDeclForActor(actor.ptype, actor.side), + Whitespace.NL + ]) + + self.cls.addstmt(Label.PROTECTED) + for typedef in p.cxxTypedefs(): + self.cls.addstmt(typedef) + for typedef in self.includedActorTypedefs: + self.cls.addstmt(typedef) + + self.cls.addstmt(Whitespace.NL) + + self.cls.addstmts([ Typedef(p.fqStateType(), 'State'), Whitespace.NL ]) + + # interface methods that the concrete subclass has to impl + for md in p.messageDecls: + isctor, isdtor = md.decl.type.isCtor(), md.decl.type.isDtor() + + if self.receivesMessage(md): + # generate Recv/Answer* interface + implicit = (not isdtor) + recvDecl = MethodDecl( + md.recvMethod().name, + params=md.makeCxxParams(paramsems='move', returnsems='out', + side=self.side, implicit=implicit), + ret=Type.BOOL, virtual=1) + + if isctor or isdtor: + defaultRecv = MethodDefn(recvDecl) + defaultRecv.addstmt(StmtReturn.TRUE) + self.cls.addstmt(defaultRecv) + else: + recvDecl.pure = 1 + self.cls.addstmt(StmtDecl(recvDecl)) + + for md in p.messageDecls: + managed = md.decl.type.constructedType() + if not ptype.isManagerOf(managed) or md.decl.type.isDtor(): + continue + + # add the Alloc/Dealloc interface for managed actors + actortype = md.actorDecl().bareType(self.side) + + self.cls.addstmt(StmtDecl(MethodDecl( + _allocMethod(managed, self.side).name, + params=md.makeCxxParams(side=self.side, implicit=0), + ret=actortype, + virtual=1, pure=1))) + + self.cls.addstmt(StmtDecl(MethodDecl( + _deallocMethod(managed, self.side).name, + params=[ Decl(actortype, 'aActor') ], + ret=Type.BOOL, + virtual=1, pure=1))) + + for actor in channelOpenedActors: + # add the Alloc interface for actors created when a + # new channel is opened + actortype = _cxxBareType(actor.asType(), actor.side) + self.cls.addstmt(StmtDecl(MethodDecl( + _allocMethod(actor.ptype, actor.side).name, + params=[ Decl(Type('Transport', ptr=1), 'aTransport'), + Decl(Type('ProcessId'), 'aOtherPid') ], + ret=actortype, + virtual=1, pure=1))) + + # ActorDestroy() method; default is no-op + self.cls.addstmts([ + Whitespace.NL, + MethodDefn(MethodDecl( + _destroyMethod().name, + params=[ Decl(_DestroyReason.Type(), 'aWhy') ], + ret=Type.VOID, + virtual=1, pure=(self.side == 'parent'))), + Whitespace.NL + ]) + + if ptype.isToplevel(): + # void ProcessingError(code); default to no-op + processingerror = MethodDefn( + MethodDecl(p.processingErrorVar().name, + params=[ Param(_Result.Type(), 'aCode'), + Param(Type('char', const=1, ptr=1), 'aReason') ], + virtual=1)) + + # bool ShouldContinueFromReplyTimeout(); default to |true| + shouldcontinue = MethodDefn( + MethodDecl(p.shouldContinueFromTimeoutVar().name, + ret=Type.BOOL, virtual=1)) + shouldcontinue.addstmt(StmtReturn.TRUE) + + # void Entered*()/Exited*(); default to no-op + entered = MethodDefn( + MethodDecl(p.enteredCxxStackVar().name, virtual=1)) + exited = MethodDefn( + MethodDecl(p.exitedCxxStackVar().name, virtual=1)) + enteredcall = MethodDefn( + MethodDecl(p.enteredCallVar().name, virtual=1)) + exitedcall = MethodDefn( + MethodDecl(p.exitedCallVar().name, virtual=1)) + + self.cls.addstmts([ processingerror, + shouldcontinue, + entered, exited, + enteredcall, exitedcall, + Whitespace.NL ]) + + self.cls.addstmts(( + [ Label.PUBLIC ] + + self.standardTypedefs() + + [ Whitespace.NL ] + )) + + self.cls.addstmt(Label.PUBLIC) + # Actor() + ctor = ConstructorDefn(ConstructorDecl(self.clsname)) + side = ExprVar('mozilla::ipc::' + self.side.title() + 'Side') + if ptype.isToplevel(): + ctor.memberinits = [ + ExprMemberInit(ExprVar('mozilla::ipc::IToplevelProtocol'), + [_protocolId(ptype), side]), + ExprMemberInit(p.channelVar(), [ + ExprCall(ExprVar('ALLOW_THIS_IN_INITIALIZER_LIST'), + [ ExprVar.THIS ]) ]), + ExprMemberInit(p.stateVar(), + [ p.startState() ]) + ] + else: + ctor.memberinits = [ + ExprMemberInit(ExprVar('mozilla::ipc::IProtocol'), [side]), + ExprMemberInit(p.stateVar(), + [ p.deadState() ]) + ] + + ctor.addstmt(StmtExpr(ExprCall(ExprVar('MOZ_COUNT_CTOR'), + [ ExprVar(self.clsname) ]))) + self.cls.addstmts([ ctor, Whitespace.NL ]) + + # ~Actor() + dtor = DestructorDefn( + DestructorDecl(self.clsname, virtual=True)) + dtor.addstmt(StmtExpr(ExprCall(ExprVar('MOZ_COUNT_DTOR'), + [ ExprVar(self.clsname) ]))) + + self.cls.addstmts([ dtor, Whitespace.NL ]) + + if not ptype.isToplevel(): + if 1 == len(p.managers): + ## manager() const + managertype = p.managerActorType(self.side, ptr=1) + managermeth = MethodDefn(MethodDecl( + 'Manager', ret=managertype, const=1)) + managerexp = ExprCall(ExprVar('IProtocol::Manager'), args=[]) + managermeth.addstmt(StmtReturn( + ExprCast(managerexp, managertype, static=1))) + + self.cls.addstmts([ managermeth, Whitespace.NL ]) + + def actorFromIter(itervar): + return ExprCall(ExprSelect(ExprCall(ExprSelect(itervar, '.', 'Get')), + '->', 'GetKey')) + def forLoopOverHashtable(hashtable, itervar, const=False): + return StmtFor( + init=Param(Type.AUTO, itervar.name, + ExprCall(ExprSelect(hashtable, '.', 'ConstIter' if const else 'Iter'))), + cond=ExprNot(ExprCall(ExprSelect(itervar, '.', 'Done'))), + update=ExprCall(ExprSelect(itervar, '.', 'Next'))) + + ## Managed[T](Array& inout) const + ## const Array<T>& Managed() const + for managed in ptype.manages: + arrvar = ExprVar('aArr') + meth = MethodDefn(MethodDecl( + p.managedMethod(managed, self.side).name, + params=[ Decl(_cxxArrayType(p.managedCxxType(managed, self.side), ref=1), + arrvar.name) ], + const=1)) + meth.addstmt(StmtExpr( + ExprCall(ExprSelect(p.managedVar(managed, self.side), + '.', 'ToArray'), + args=[ arrvar ]))) + + refmeth = MethodDefn(MethodDecl( + p.managedMethod(managed, self.side).name, + params=[ ], + ret=p.managedVarType(managed, self.side, const=1, ref=1), + const=1)) + refmeth.addstmt(StmtReturn(p.managedVar(managed, self.side))) + + self.cls.addstmts([ meth, refmeth, Whitespace.NL ]) + + statemethod = MethodDefn(MethodDecl( + p.stateMethod().name, + ret=p.fqStateType())) + statemethod.addstmt(StmtReturn(p.stateVar())) + self.cls.addstmts([ statemethod, Whitespace.NL ]) + + ## OnMessageReceived()/OnCallReceived() + + # save these away for use in message handler case stmts + msgvar = ExprVar('msg__') + self.msgvar = msgvar + replyvar = ExprVar('reply__') + self.replyvar = replyvar + itervar = ExprVar('iter__') + self.itervar = itervar + var = ExprVar('v__') + self.var = var + # for ctor recv cases, we can't read the actor ID into a PFoo* + # because it doesn't exist on this side yet. Use a "special" + # actor handle instead + handlevar = ExprVar('handle__') + self.handlevar = handlevar + + msgtype = ExprCall(ExprSelect(msgvar, '.', 'type'), [ ]) + self.asyncSwitch = StmtSwitch(msgtype) + self.syncSwitch = None + self.interruptSwitch = None + if toplevel.isSync() or toplevel.isInterrupt(): + self.syncSwitch = StmtSwitch(msgtype) + if toplevel.isInterrupt(): + self.interruptSwitch = StmtSwitch(msgtype) + + # implement Send*() methods and add dispatcher cases to + # message switch()es + for md in p.messageDecls: + self.visitMessageDecl(md) + + # Handlers for the creation of actors when a new channel is + # opened + if len(channelOpenedActors): + self.makeChannelOpenedHandlers(channelOpenedActors) + + # add default cases + default = StmtBlock() + default.addstmt(StmtReturn(_Result.NotKnown)) + self.asyncSwitch.addcase(DefaultLabel(), default) + if toplevel.isSync() or toplevel.isInterrupt(): + self.syncSwitch.addcase(DefaultLabel(), default) + if toplevel.isInterrupt(): + self.interruptSwitch.addcase(DefaultLabel(), default) + + # FIXME/bug 535053: only manager protocols and non-manager + # protocols with union types need Lookup(). we'll give it to + # all for the time being (simpler) + if 1 or ptype.isManager(): + self.cls.addstmts(self.implementManagerIface()) + + def makeHandlerMethod(name, switch, hasReply, dispatches=0): + params = [ Decl(Type('Message', const=1, ref=1), msgvar.name) ] + if hasReply: + params.append(Decl(Type('Message', ref=1, ptr=1), + replyvar.name)) + + method = MethodDefn(MethodDecl(name, virtual=True, + params=params, ret=_Result.Type())) + + if not switch: + crash = StmtExpr(ExprCall(ExprVar('MOZ_ASSERT_UNREACHABLE'), + args=[ExprLiteral.String('message protocol not supported')])) + method.addstmts([crash, StmtReturn(_Result.NotKnown)]) + return method + + if dispatches: + routevar = ExprVar('route__') + routedecl = StmtDecl( + Decl(_actorIdType(), routevar.name), + init=ExprCall(ExprSelect(msgvar, '.', 'routing_id'))) + + routeif = StmtIf(ExprBinary( + ExprVar('MSG_ROUTING_CONTROL'), '!=', routevar)) + routedvar = ExprVar('routed__') + routeif.ifb.addstmt( + StmtDecl(Decl(Type('ChannelListener', ptr=1), + routedvar.name), + _lookupListener(routevar))) + failif = StmtIf(ExprPrefixUnop(routedvar, '!')) + failif.ifb.addstmt(StmtReturn(_Result.RouteError)) + routeif.ifb.addstmt(failif) + + routeif.ifb.addstmt(StmtReturn(ExprCall( + ExprSelect(routedvar, '->', name), + args=[ ExprVar(p.name) for p in params ]))) + + method.addstmts([ routedecl, routeif, Whitespace.NL ]) + + # in the event of an Interrupt delete message, we want to loudly complain about + # messages that are received that are not a reply to the original message + if ptype.hasReentrantDelete: + msgVar = ExprVar(params[0].name) + ifdying = StmtIf(ExprBinary( + ExprBinary(ExprVar('mState'), '==', _dyingState(ptype)), + '&&', + ExprBinary( + ExprBinary(ExprCall(ExprSelect(msgVar, '.', 'is_reply')), '!=', ExprLiteral.TRUE), + '||', + ExprBinary(ExprCall(ExprSelect(msgVar, '.', 'is_interrupt')), '!=', ExprLiteral.TRUE)))) + ifdying.addifstmts([_fatalError('incoming message racing with actor deletion'), + StmtReturn(_Result.Processed)]) + method.addstmt(ifdying) + + # bug 509581: don't generate the switch stmt if there + # is only the default case; MSVC doesn't like that + if switch.nr_cases > 1: + method.addstmt(switch) + else: + method.addstmt(StmtReturn(_Result.NotKnown)) + + return method + + dispatches = (ptype.isToplevel() and ptype.isManager()) + self.cls.addstmts([ + makeHandlerMethod('OnMessageReceived', self.asyncSwitch, + hasReply=0, dispatches=dispatches), + Whitespace.NL + ]) + self.cls.addstmts([ + makeHandlerMethod('OnMessageReceived', self.syncSwitch, + hasReply=1, dispatches=dispatches), + Whitespace.NL + ]) + self.cls.addstmts([ + makeHandlerMethod('OnCallReceived', self.interruptSwitch, + hasReply=1, dispatches=dispatches), + Whitespace.NL + ]) + + destroysubtreevar = ExprVar('DestroySubtree') + deallocsubtreevar = ExprVar('DeallocSubtree') + deallocshmemvar = ExprVar('DeallocShmems') + deallocselfvar = ExprVar('Dealloc' + _actorName(ptype.name(), self.side)) + + # int32_t GetProtocolTypeId() { return PFoo; } + gettypetag = MethodDefn( + MethodDecl('GetProtocolTypeId', ret=_actorTypeTagType())) + gettypetag.addstmt(StmtReturn(_protocolId(ptype))) + self.cls.addstmts([ gettypetag, Whitespace.NL ]) + + if ptype.isToplevel(): + # OnChannelClose() + onclose = MethodDefn(MethodDecl('OnChannelClose')) + onclose.addstmts([ + StmtExpr(ExprCall(destroysubtreevar, + args=[ _DestroyReason.NormalShutdown ])), + StmtExpr(ExprCall(deallocsubtreevar)), + StmtExpr(ExprCall(deallocshmemvar)), + StmtExpr(ExprCall(deallocselfvar)) + ]) + self.cls.addstmts([ onclose, Whitespace.NL ]) + + # OnChannelError() + onerror = MethodDefn(MethodDecl('OnChannelError')) + onerror.addstmts([ + StmtExpr(ExprCall(destroysubtreevar, + args=[ _DestroyReason.AbnormalShutdown ])), + StmtExpr(ExprCall(deallocsubtreevar)), + StmtExpr(ExprCall(deallocshmemvar)), + StmtExpr(ExprCall(deallocselfvar)) + ]) + self.cls.addstmts([ onerror, Whitespace.NL ]) + + if (ptype.isToplevel() and ptype.isInterrupt()): + + processnative = MethodDefn( + MethodDecl('ProcessNativeEventsInInterruptCall', ret=Type.VOID)) + + processnative.addstmts([ + CppDirective('ifdef', 'OS_WIN'), + StmtExpr(ExprCall( + ExprSelect(p.channelVar(), '.', + 'ProcessNativeEventsInInterruptCall'))), + CppDirective('else'), + _fatalError('This method is Windows-only'), + CppDirective('endif'), + ]) + + self.cls.addstmts([ processnative, Whitespace.NL ]) + + ## private methods + self.cls.addstmt(Label.PRIVATE) + + ## ProtocolName() + actorname = _actorName(p.name, self.side) + protocolname = MethodDefn(MethodDecl( + 'ProtocolName', params=[], + const=1, virtual=1, ret=Type('char', const=1, ptr=1))) + protocolname.addstmts([ + StmtReturn(ExprLiteral.String(actorname)) + ]) + self.cls.addstmts([ protocolname, Whitespace.NL ]) + + ## DestroySubtree(bool normal) + whyvar = ExprVar('why') + subtreewhyvar = ExprVar('subtreewhy') + kidsvar = ExprVar('kids') + ivar = ExprVar('i') + itervar = ExprVar('iter') + ithkid = ExprIndex(kidsvar, ivar) + + destroysubtree = MethodDefn(MethodDecl( + destroysubtreevar.name, + params=[ Decl(_DestroyReason.Type(), whyvar.name) ])) + + if ptype.isManaged(): + destroysubtree.addstmt( + Whitespace('// Unregister from our manager.\n', indent=1)) + destroysubtree.addstmts(self.unregisterActor()) + destroysubtree.addstmt(Whitespace.NL) + + if ptype.isManager(): + # only declare this for managers to avoid unused var warnings + destroysubtree.addstmts([ + StmtDecl( + Decl(_DestroyReason.Type(), subtreewhyvar.name), + init=ExprConditional( + ExprBinary( + ExprBinary(whyvar, '==', + _DestroyReason.Deletion), + '||', + ExprBinary(whyvar, '==', + _DestroyReason.FailedConstructor)), + _DestroyReason.AncestorDeletion, whyvar)), + Whitespace.NL + ]) + + for managed in ptype.manages: + managedVar = p.managedVar(managed, self.side) + lenvar = ExprVar('len') + kidvar = ExprVar('kid') + + foreachdestroy = StmtRangedFor(kidvar, kidsvar) + + foreachdestroy.addstmt( + Whitespace('// Guarding against a child removing a sibling from the list during the iteration.\n', indent=1)) + ifhas = StmtIf(_callHasManagedActor(managedVar, kidvar)) + ifhas.addifstmt(StmtExpr(ExprCall( + ExprSelect(kidvar, '->', destroysubtreevar.name), + args=[ subtreewhyvar ]))) + foreachdestroy.addstmt(ifhas) + + block = StmtBlock() + block.addstmts([ + Whitespace( + '// Recursively shutting down %s kids\n'% (managed.name()), + indent=1), + StmtDecl( + Decl(_cxxArrayType(p.managedCxxType(managed, self.side)), kidsvar.name)), + Whitespace( + '// Accumulate kids into a stable structure to iterate over\n', + indent=1), + StmtExpr(ExprCall(p.managedMethod(managed, self.side), + args=[ kidsvar ])), + foreachdestroy, + ]) + destroysubtree.addstmt(block) + + if len(ptype.manages): + destroysubtree.addstmt(Whitespace.NL) + destroysubtree.addstmts([ Whitespace('// Finally, destroy "us".\n', + indent=1), + StmtExpr(ExprCall(_destroyMethod(), + args=[ whyvar ])) + ]) + + self.cls.addstmts([ destroysubtree, Whitespace.NL ]) + + ## DeallocSubtree() + deallocsubtree = MethodDefn(MethodDecl(deallocsubtreevar.name)) + for managed in ptype.manages: + managedVar = p.managedVar(managed, self.side) + + foreachrecurse = forLoopOverHashtable(managedVar, itervar) + foreachrecurse.addstmt(StmtExpr(ExprCall( + ExprSelect(actorFromIter(itervar), '->', deallocsubtreevar.name)))) + + foreachdealloc = forLoopOverHashtable(managedVar, itervar) + foreachdealloc.addstmts([ + StmtExpr(ExprCall(_deallocMethod(managed, self.side), + args=[ actorFromIter(itervar) ])) + ]) + + block = StmtBlock() + block.addstmts([ + Whitespace( + '// Recursively deleting %s kids\n'% (managed.name()), + indent=1), + foreachrecurse, + Whitespace.NL, + foreachdealloc, + StmtExpr(_callClearManagedActors(managedVar)), + + ]) + deallocsubtree.addstmt(block) + # don't delete outselves: either the manager will do it, or + # we're toplevel + self.cls.addstmts([ deallocsubtree, Whitespace.NL ]) + + if ptype.isToplevel(): + deallocself = MethodDefn(MethodDecl(deallocselfvar.name, virtual=1)) + self.cls.addstmts([ deallocself, Whitespace.NL ]) + + self.implementPickling() + + ## private members + if ptype.isToplevel(): + self.cls.addstmt(StmtDecl(Decl(p.channelType(), 'mChannel'))) + + self.cls.addstmt(StmtDecl(Decl(Type('State'), p.stateVar().name))) + + for managed in ptype.manages: + self.cls.addstmts([ + StmtDecl(Decl( + p.managedVarType(managed, self.side), + p.managedVar(managed, self.side).name)) ]) + + def implementManagerIface(self): + p = self.protocol + routedvar = ExprVar('aRouted') + idvar = ExprVar('aId') + shmemvar = ExprVar('shmem') + rawvar = ExprVar('segment') + sizevar = ExprVar('aSize') + typevar = ExprVar('aType') + unsafevar = ExprVar('aUnsafe') + protocolbase = Type('ProtocolBase', ptr=1) + sourcevar = ExprVar('aSource') + ivar = ExprVar('i') + kidsvar = ExprVar('kids') + ithkid = ExprIndex(kidsvar, ivar) + + methods = [] + + if p.decl.type.isToplevel(): + getchannel = MethodDefn(MethodDecl( + p.getChannelMethod().name, + ret=Type('MessageChannel', ptr=1), + virtual=1)) + getchannel.addstmt(StmtReturn(ExprAddrOf(p.channelVar()))) + + getchannelconst = MethodDefn(MethodDecl( + p.getChannelMethod().name, + ret=Type('MessageChannel', ptr=1, const=1), + virtual=1, const=1)) + getchannelconst.addstmt(StmtReturn(ExprAddrOf(p.channelVar()))) + + methods += [ getchannel, + getchannelconst ] + + if p.decl.type.isToplevel(): + tmpvar = ExprVar('tmp') + + # "private" message that passes shmem mappings from one process + # to the other + if p.subtreeUsesShmem(): + self.asyncSwitch.addcase( + CaseLabel('SHMEM_CREATED_MESSAGE_TYPE'), + self.genShmemCreatedHandler()) + self.asyncSwitch.addcase( + CaseLabel('SHMEM_DESTROYED_MESSAGE_TYPE'), + self.genShmemDestroyedHandler()) + else: + abort = StmtBlock() + abort.addstmts([ + _fatalError('this protocol tree does not use shmem'), + StmtReturn(_Result.NotKnown) + ]) + self.asyncSwitch.addcase( + CaseLabel('SHMEM_CREATED_MESSAGE_TYPE'), abort) + self.asyncSwitch.addcase( + CaseLabel('SHMEM_DESTROYED_MESSAGE_TYPE'), abort) + + othervar = ExprVar('other') + managertype = Type(_actorName(p.name, self.side), ptr=1) + + # Keep track of types created with an INOUT ctor. We need to call + # Register() or RegisterID() for them depending on the side the managee + # is created. + inoutCtorTypes = [] + for msg in p.messageDecls: + msgtype = msg.decl.type + if msgtype.isCtor() and msgtype.isInout(): + inoutCtorTypes.append(msgtype.constructedType()) + + # all protocols share the "same" RemoveManagee() implementation + pvar = ExprVar('aProtocolId') + listenervar = ExprVar('aListener') + removemanagee = MethodDefn(MethodDecl( + p.removeManageeMethod().name, + params=[ Decl(_protocolIdType(), pvar.name), + Decl(protocolbase, listenervar.name) ], + virtual=1)) + + if not len(p.managesStmts): + removemanagee.addstmts([ _fatalError('unreached'), StmtReturn() ]) + else: + switchontype = StmtSwitch(pvar) + for managee in p.managesStmts: + case = StmtBlock() + actorvar = ExprVar('actor') + manageeipdltype = managee.decl.type + manageecxxtype = _cxxBareType(ipdl.type.ActorType(manageeipdltype), + self.side) + manageearray = p.managedVar(manageeipdltype, self.side) + containervar = ExprVar('container') + + case.addstmts([ + StmtDecl(Decl(manageecxxtype, actorvar.name), + ExprCast(listenervar, manageecxxtype, static=1)), + # Use a temporary variable here so all the assertion expressions + # in the _abortIfFalse call below are textually identical; the + # linker can then merge the strings from the assertion macro(s). + StmtDecl(Decl(Type('auto', ref=1), containervar.name), + manageearray), + _abortIfFalse( + _callHasManagedActor(containervar, actorvar), + "actor not managed by this!"), + Whitespace.NL, + StmtExpr(_callRemoveManagedActor(containervar, actorvar)), + StmtExpr(ExprCall(_deallocMethod(manageeipdltype, self.side), + args=[ actorvar ])), + StmtReturn() + ]) + switchontype.addcase(CaseLabel(_protocolId(manageeipdltype).name), + case) + default = StmtBlock() + default.addstmts([ _fatalError('unreached'), StmtReturn() ]) + switchontype.addcase(DefaultLabel(), default) + removemanagee.addstmt(switchontype) + + return methods + [removemanagee, Whitespace.NL] + + def genShmemCreatedHandler(self): + p = self.protocol + assert p.decl.type.isToplevel() + + case = StmtBlock() + + ifstmt = StmtIf(ExprNot(ExprCall(ExprVar('ShmemCreated'), args=[self.msgvar]))) + case.addstmts([ + ifstmt, + StmtReturn(_Result.Processed) + ]) + ifstmt.addifstmt(StmtReturn(_Result.PayloadError)) + + return case + + def genShmemDestroyedHandler(self): + p = self.protocol + assert p.decl.type.isToplevel() + + case = StmtBlock() + + ifstmt = StmtIf(ExprNot(ExprCall(ExprVar('ShmemDestroyed'), args=[self.msgvar]))) + case.addstmts([ + ifstmt, + StmtReturn(_Result.Processed) + ]) + ifstmt.addifstmt(StmtReturn(_Result.PayloadError)) + + return case + + + def makeChannelOpenedHandlers(self, actors): + handlers = StmtBlock() + + # unpack the transport descriptor et al. + msgvar = self.msgvar + tdvar = ExprVar('td') + pidvar = ExprVar('pid') + pvar = ExprVar('protocolid') + iffail = StmtIf(ExprNot(ExprCall( + ExprVar('mozilla::ipc::UnpackChannelOpened'), + args=[ _backstagePass(), + msgvar, + ExprAddrOf(tdvar), ExprAddrOf(pidvar), ExprAddrOf(pvar) ]))) + iffail.addifstmt(StmtReturn(_Result.PayloadError)) + handlers.addstmts([ + StmtDecl(Decl(Type('TransportDescriptor'), tdvar.name)), + StmtDecl(Decl(Type('ProcessId'), pidvar.name)), + StmtDecl(Decl(Type('ProtocolId'), pvar.name)), + iffail, + Whitespace.NL + ]) + + def makeHandlerCase(actor): + self.protocolCxxIncludes.append(_protocolHeaderName(actor.ptype._ast, + actor.side)) + + case = StmtBlock() + modevar = _sideToTransportMode(actor.side) + tvar = ExprVar('t') + iffailopen = StmtIf(ExprNot(ExprAssn( + tvar, + ExprCall(ExprVar('mozilla::ipc::OpenDescriptor'), + args=[ tdvar, modevar ])))) + iffailopen.addifstmt(StmtReturn(_Result.ValuError)) + + pvar = ExprVar('p') + iffailalloc = StmtIf(ExprNot(ExprAssn( + pvar, + ExprCall( + _allocMethod(actor.ptype, actor.side), + args=[ _uniqueptrGet(tvar), pidvar ])))) + iffailalloc.addifstmt(StmtReturn(_Result.ProcessingError)) + + settrans = StmtExpr(ExprCall( + ExprSelect(pvar, '->', 'IToplevelProtocol::SetTransport'), + args=[ExprMove(tvar)])) + + case.addstmts([ + StmtDecl(Decl(_uniqueptr(Type('Transport')), tvar.name)), + StmtDecl(Decl(Type(_actorName(actor.ptype.name(), actor.side), + ptr=1), pvar.name)), + iffailopen, + iffailalloc, + settrans, + StmtBreak() + ]) + label = _messageStartName(actor.ptype) + if actor.side == 'child': + label += 'Child' + return CaseLabel(label), case + + pswitch = StmtSwitch(pvar) + for actor in actors: + label, case = makeHandlerCase(actor) + pswitch.addcase(label, case) + + die = Block() + die.addstmts([ _fatalError('Invalid protocol'), + StmtReturn(_Result.ValuError) ]) + pswitch.addcase(DefaultLabel(), die) + + handlers.addstmts([ + pswitch, + StmtReturn(_Result.Processed) + ]) + self.asyncSwitch.addcase(CaseLabel('CHANNEL_OPENED_MESSAGE_TYPE'), + handlers) + + ##------------------------------------------------------------------------- + ## The next few functions are the crux of the IPDL code generator. + ## They generate code for all the nasty work of message + ## serialization/deserialization and dispatching handlers for + ## received messages. + ## + def implementPickling(self): + # pickling of "normal", non-IPDL types + self.implementGenericPickling() + + # pickling for IPDL types + specialtypes = set() + class findSpecialTypes(TypeVisitor): + def visitActorType(self, a): specialtypes.add(a) + def visitShmemType(self, s): specialtypes.add(s) + def visitFDType(self, s): specialtypes.add(s) + def visitStructType(self, s): + specialtypes.add(s) + return TypeVisitor.visitStructType(self, s) + def visitUnionType(self, u): + specialtypes.add(u) + return TypeVisitor.visitUnionType(self, u) + def visitArrayType(self, a): + if a.basetype.isIPDL(): + specialtypes.add(a) + return a.basetype.accept(self) + + for md in self.protocol.messageDecls: + for param in md.params: + mtype = md.decl.type + # special case for top-level __delete__(), which isn't + # understood yet + if mtype.isDtor() and mtype.constructedType().isToplevel(): + continue + param.ipdltype.accept(findSpecialTypes()) + for ret in md.returns: + ret.ipdltype.accept(findSpecialTypes()) + + for t in specialtypes: + if t.isActor(): self.implementActorPickling(t) + elif t.isArray(): self.implementSpecialArrayPickling(t) + elif t.isShmem(): self.implementShmemPickling(t) + elif t.isFD(): self.implementFDPickling(t) + elif t.isStruct(): self.implementStructPickling(t) + elif t.isUnion(): self.implementUnionPickling(t) + else: + assert 0 and 'unknown special type' + + def implementGenericPickling(self): + var = self.var + msgvar = self.msgvar + itervar = self.itervar + + write = MethodDefn(self.writeMethodDecl( + Type('T', const=1, ref=1), var, template=Type('T'))) + write.addstmt(StmtExpr(ExprCall(ExprVar('IPC::WriteParam'), + args=[ msgvar, var ]))) + + read = MethodDefn(self.readMethodDecl( + Type('T', ptr=1), var, template=Type('T'))) + read.addstmt(StmtReturn(ExprCall(ExprVar('IPC::ReadParam'), + args=[ msgvar, itervar, var ]))) + + self.cls.addstmts([ write, Whitespace.NL, read, Whitespace.NL ]) + + def implementActorPickling(self, actortype): + # Note that we pickle based on *protocol* type and *not* actor + # type. The actor type includes a |nullable| qualifier, but + # this method is not specialized based on nullability. The + # |actortype| nullability is ignored in this method. + var = self.var + idvar = ExprVar('id') + intype = _cxxConstRefType(actortype, self.side) + # XXX the writer code can treat the actor as logically const; many + # other places that call _cxxConstRefType cannot treat the actor + # as logically const, particularly callers that can leak out to + # Gecko directly. + intype.const = 1 + cxxtype = _cxxBareType(actortype, self.side) + outtype = _cxxPtrToType(actortype, self.side) + + ## Write([const] PFoo* var) + write = MethodDefn(self.writeMethodDecl(intype, var)) + nullablevar = ExprVar('nullable__') + write.decl.params.append(Decl(Type.BOOL, nullablevar.name)) + # id_t id; + # if (!var) + # if(!nullable) + # abort() + # id = NULL_ID + write.addstmt(StmtDecl(Decl(_actorIdType(), idvar.name))) + + ifnull = StmtIf(ExprNot(var)) + ifnotnullable = StmtIf(ExprNot(nullablevar)) + ifnotnullable.addifstmt( + _fatalError("NULL actor value passed to non-nullable param")) + ifnull.addifstmt(ifnotnullable) + ifnull.addifstmt(StmtExpr(ExprAssn(idvar, _NULL_ACTOR_ID))) + # else + # id = var->mId + # if (id == FREED_ID) + # abort() + # Write(msg, id) + ifnull.addelsestmt(StmtExpr(ExprAssn(idvar, _actorId(var)))) + iffreed = StmtIf(ExprBinary(_FREED_ACTOR_ID, '==', idvar)) + # this is always a hard-abort, because it means that some C++ + # code has a live pointer to a freed actor, so we're playing + # Russian roulette with invalid memory + iffreed.addifstmt(_fatalError("actor has been |delete|d")) + ifnull.addelsestmt(iffreed) + + write.addstmts([ + ifnull, + Whitespace.NL, + StmtExpr(self.write(None, idvar, self.msgvar)) + ]) + + ## Read(PFoo** var) + read = MethodDefn(self.readMethodDecl(outtype, var)) + read.decl.params.append(Decl(Type.BOOL, nullablevar.name)) + + actorvar = ExprVar('actor') + read.addstmts([ + StmtDecl(Decl(Type('Maybe', T=Type('mozilla::ipc::IProtocol', ptr=1)), actorvar.name), + init=ExprCall(ExprVar('ReadActor'), + args=[ self.msgvar, self.itervar, nullablevar, + ExprLiteral.String(actortype.name()), + _protocolId(actortype) ])), + ]) + + # if (actor.isNothing()) + # return false + # + # Reading the actor failed in some way, and the appropriate error was raised. + ifnothing = StmtIf(ExprCall(ExprSelect(actorvar, '.', 'isNothing'))) + ifnothing.addifstmts([ + StmtReturn.FALSE, + ]) + + read.addstmts([ ifnothing, Whitespace.NL ]) + + read.addstmts([ + StmtExpr(ExprAssn(ExprDeref(var), + ExprCast(ExprCall(ExprSelect(actorvar, '.', 'value')), cxxtype, static=1))), + StmtReturn.TRUE + ]) + + self.cls.addstmts([ write, Whitespace.NL, read, Whitespace.NL ]) + + + def implementSpecialArrayPickling(self, arraytype): + var = self.var + msgvar = self.msgvar + itervar = self.itervar + lenvar = ExprVar('length') + ivar = ExprVar('i') + eltipdltype = arraytype.basetype + intype = _cxxConstRefType(arraytype, self.side) + outtype = _cxxPtrToType(arraytype, self.side) + + # We access elements directly in Read and Write to avoid array bounds + # checking. + directtype = _cxxBareType(arraytype.basetype, self.side) + if directtype.ptr: + typeinit = { 'ptrptr': 1 } + else: + typeinit = { 'ptr': 1 } + directtype = Type(directtype.name, **typeinit) + elemsvar = ExprVar('elems') + elemvar = ExprVar('elem') + + write = MethodDefn(self.writeMethodDecl(intype, var)) + forwrite = StmtRangedFor(elemvar, var) + forwrite.addstmt( + self.checkedWrite(eltipdltype, elemvar, msgvar, + sentinelKey=arraytype.name())) + write.addstmts([ + StmtDecl(Decl(Type.UINT32, lenvar.name), + init=_callCxxArrayLength(var)), + self.checkedWrite(None, lenvar, msgvar, sentinelKey=('length', arraytype.name())), + Whitespace.NL, + forwrite + ]) + + read = MethodDefn(self.readMethodDecl(outtype, var)) + favar = ExprVar('fa') + forread = StmtFor(init=ExprAssn(Decl(Type.UINT32, ivar.name), + ExprLiteral.ZERO), + cond=ExprBinary(ivar, '<', lenvar), + update=ExprPrefixUnop(ivar, '++')) + forread.addstmt( + self.checkedRead(eltipdltype, ExprAddrOf(ExprIndex(elemsvar, ivar)), + msgvar, itervar, errfnRead, + '\'' + eltipdltype.name() + '[i]\'', + sentinelKey=arraytype.name())) + appendstmt = StmtDecl(Decl(directtype, elemsvar.name), + init=ExprCall(ExprSelect(favar, '.', 'AppendElements'), + args=[ lenvar ])) + read.addstmts([ + StmtDecl(Decl(_cxxArrayType(_cxxBareType(arraytype.basetype, self.side)), favar.name)), + StmtDecl(Decl(Type.UINT32, lenvar.name)), + self.checkedRead(None, ExprAddrOf(lenvar), + msgvar, itervar, errfnArrayLength, + [ arraytype.name() ], + sentinelKey=('length', arraytype.name())), + Whitespace.NL, + appendstmt, + forread, + StmtExpr(_callCxxSwapArrayElements(var, favar, '->')), + StmtReturn.TRUE + ]) + + self.cls.addstmts([ write, Whitespace.NL, read, Whitespace.NL ]) + + + def implementShmemPickling(self, shmemtype): + msgvar = self.msgvar + itervar = self.itervar + var = self.var + tmpvar = ExprVar('tmp') + idvar = ExprVar('shmemid') + rawvar = ExprVar('rawmem') + baretype = _cxxBareType(shmemtype, self.side) + intype = _cxxConstRefType(shmemtype, self.side) + outtype = _cxxPtrToType(shmemtype, self.side) + + write = MethodDefn(self.writeMethodDecl(intype, var)) + write.addstmts([ + StmtExpr(ExprCall(ExprVar('IPC::WriteParam'), + args=[ msgvar, var ])), + StmtExpr(_shmemRevokeRights(var)), + StmtExpr(_shmemForget(var)) + ]) + + read = MethodDefn(self.readMethodDecl(outtype, var)) + ifread = StmtIf(ExprNot(ExprCall(ExprVar('IPC::ReadParam'), + args=[ msgvar, itervar, + ExprAddrOf(tmpvar) ]))) + ifread.addifstmt(StmtReturn.FALSE) + + iffound = StmtIf(rawvar) + iffound.addifstmt(StmtExpr(ExprAssn( + ExprDeref(var), _shmemCtor(rawvar, idvar)))) + iffound.addifstmt(StmtReturn.TRUE) + + read.addstmts([ + StmtDecl(Decl(_shmemType(), tmpvar.name)), + ifread, + Whitespace.NL, + StmtDecl(Decl(_shmemIdType(), idvar.name), + init=_shmemId(tmpvar)), + StmtDecl(Decl(_rawShmemType(ptr=1), rawvar.name), + init=_lookupShmem(idvar)), + iffound, + # This is ugly: we failed to look the shmem up, most likely because + # we failed to map it the first time it was deserialized. we create + # an empty shmem and let the user of the shmem deal with it. + # if we returned false here we would crash. + StmtExpr(ExprAssn(ExprDeref(var), ExprCall(ExprVar('Shmem'), args=[]) )), + StmtReturn.TRUE + ]) + + self.cls.addstmts([ write, Whitespace.NL, read, Whitespace.NL ]) + + def implementFDPickling(self, fdtype): + msgvar = self.msgvar + itervar = self.itervar + var = self.var + tmpvar = ExprVar('fd') + picklevar = ExprVar('pfd') + intype = _cxxConstRefType(fdtype, self.side) + outtype = _cxxPtrToType(fdtype, self.side) + + def _fdType(): + return Type('FileDescriptor') + + def _fdPickleType(): + return Type('FileDescriptor::PickleType') + + def _fdBackstagePass(): + return ExprCall(ExprVar('FileDescriptor::IPDLPrivate')) + + write = MethodDefn(self.writeMethodDecl(intype, var)) + write.addstmts([ + StmtDecl(Decl(_fdPickleType(), picklevar.name), + init=ExprCall(ExprSelect(var, '.', 'ShareTo'), + args=[ _fdBackstagePass(), + self.protocol.callOtherPid() ])), + StmtExpr(ExprCall(ExprVar('IPC::WriteParam'), + args=[ msgvar, picklevar ])), + ]) + + read = MethodDefn(self.readMethodDecl(outtype, var)) + ifread = StmtIf(ExprNot(ExprCall(ExprVar('IPC::ReadParam'), + args=[ msgvar, itervar, + ExprAddrOf(picklevar) ]))) + ifread.addifstmt(StmtReturn.FALSE) + + ifnvalid = StmtIf(ExprNot(ExprCall(ExprSelect(tmpvar, '.', 'IsValid')))) + ifnvalid.addifstmt( + _protocolErrorBreakpoint('[' + + _actorName(self.protocol.name, self.side) + + '] Received an invalid file descriptor!')) + + read.addstmts([ + StmtDecl(Decl(_fdPickleType(), picklevar.name)), + ifread, + Whitespace.NL, + StmtDecl(Decl(_fdType(), tmpvar.name), + init=ExprCall(ExprVar('FileDescriptor'), + args=[ _fdBackstagePass(), picklevar ])), + ifnvalid, + Whitespace.NL, + StmtExpr(ExprAssn(ExprDeref(var), tmpvar)), + StmtReturn.TRUE + ]) + + self.cls.addstmts([ write, Whitespace.NL, read, Whitespace.NL ]) + + def implementStructPickling(self, structtype): + msgvar = self.msgvar + itervar = self.itervar + var = self.var + intype = _cxxConstRefType(structtype, self.side) + outtype = _cxxPtrToType(structtype, self.side) + sd = structtype._ast + + write = MethodDefn(self.writeMethodDecl(intype, var)) + read = MethodDefn(self.readMethodDecl(outtype, var)) + + def get(sel, f): + return ExprCall(f.getMethod(thisexpr=var, sel=sel)) + + for f in sd.fields: + desc = '\'' + f.getMethod().name + '\' (' + f.ipdltype.name() + \ + ') member of \'' + intype.name + '\'' + writefield = self.checkedWrite(f.ipdltype, get('.', f), msgvar, sentinelKey=f.basename) + readfield = self.checkedRead(f.ipdltype, + ExprAddrOf(get('->', f)), + msgvar, itervar, errfnRead, desc, sentinelKey=f.basename) + if f.special and f.side != self.side: + writefield = Whitespace( + "// skipping actor field that's meaningless on this side\n", indent=1) + readfield = Whitespace( + "// skipping actor field that's meaningless on this side\n", indent=1) + write.addstmt(writefield) + read.addstmt(readfield) + + read.addstmt(StmtReturn.TRUE) + + self.cls.addstmts([ write, Whitespace.NL, read, Whitespace.NL ]) + + + def implementUnionPickling(self, uniontype): + msgvar = self.msgvar + itervar = self.itervar + var = self.var + intype = _cxxConstRefType(uniontype, self.side) + outtype = _cxxPtrToType(uniontype, self.side) + ud = uniontype._ast + + typename = 'type__' + uniontdef = Typedef(_cxxBareType(uniontype, typename), typename) + + typevar = ExprVar('type') + writeswitch = StmtSwitch(ud.callType(var)) + readswitch = StmtSwitch(typevar) + + for c in ud.components: + ct = c.ipdltype + isactor = (ct.isIPDL() and ct.isActor()) + caselabel = CaseLabel(typename +'::'+ c.enum()) + origenum = c.enum() + + writecase = StmtBlock() + if c.special and c.side != self.side: + writecase.addstmt(_fatalError('wrong side!')) + else: + wexpr = ExprCall(ExprSelect(var, '.', c.getTypeName())) + writecase.addstmt(self.checkedWrite(ct, wexpr, msgvar, sentinelKey=c.enum())) + + writecase.addstmt(StmtReturn()) + writeswitch.addcase(caselabel, writecase) + + readcase = StmtBlock() + if c.special and c.side == self.side: + # the type comes across flipped from what the actor + # will be on this side; i.e. child->parent messages + # have type PFooChild when received on the parent side + # XXX: better error message + readcase.addstmt(StmtReturn.FALSE) + else: + if c.special: + c = c.other # see above + tmpvar = ExprVar('tmp') + ct = c.bareType() + readcase.addstmts([ + StmtDecl(Decl(ct, tmpvar.name), init=c.defaultValue()), + StmtExpr(ExprAssn(ExprDeref(var), tmpvar)), + self.checkedRead( + c.ipdltype, + ExprAddrOf(ExprCall(ExprSelect(var, '->', + c.getTypeName()))), + msgvar, itervar, errfnRead, 'Union type', sentinelKey=origenum), + StmtReturn(ExprLiteral.TRUE) + ]) + + readswitch.addcase(caselabel, readcase) + + unknowntype = 'unknown union type' + writeswitch.addcase(DefaultLabel(), + StmtBlock([ _fatalError(unknowntype), + StmtReturn() ])) + readswitch.addcase(DefaultLabel(), StmtBlock(errfnRead(unknowntype))) + + write = MethodDefn(self.writeMethodDecl(intype, var)) + write.addstmts([ + uniontdef, + self.checkedWrite( + None, ExprCall(Type.INT, args=[ ud.callType(var) ]), msgvar, + sentinelKey=uniontype.name()), + Whitespace.NL, + writeswitch + ]) + + read = MethodDefn(self.readMethodDecl(outtype, var)) + read.addstmts([ + uniontdef, + StmtDecl(Decl(Type.INT, typevar.name)), + self.checkedRead( + None, ExprAddrOf(typevar), msgvar, itervar, errfnUnionType, + [ uniontype.name() ], + sentinelKey=uniontype.name()), + Whitespace.NL, + readswitch, + ]) + + self.cls.addstmts([ write, Whitespace.NL, read, Whitespace.NL ]) + + + def writeMethodDecl(self, intype, var, template=None): + return MethodDecl( + 'Write', + params=[ Decl(intype, var.name), + Decl(Type('Message', ptr=1), self.msgvar.name) ], + T=template) + + def readMethodDecl(self, outtype, var, template=None): + return MethodDecl( + 'Read', + params=[ Decl(outtype, var.name), + Decl(Type('Message', ptr=1, const=1), + self.msgvar.name), + Decl(_iterType(ptr=1), self.itervar.name)], + warn_unused=not template, + T=template, + ret=Type.BOOL) + + def maybeAddNullabilityArg(self, ipdltype, call): + if ipdltype and ipdltype.isIPDL() and ipdltype.isActor(): + if ipdltype.nullable: + call.args.append(ExprLiteral.TRUE) + else: + call.args.append(ExprLiteral.FALSE) + return call + + def write(self, ipdltype, expr, to, this=None): + write = ExprVar('Write') + if this: write = ExprSelect(this, '->', write.name) + return self.maybeAddNullabilityArg(ipdltype, + ExprCall(write, args=[ expr, to ])) + + def read(self, ipdltype, expr, from_, iterexpr, this=None): + read = ExprVar('Read') + if this: read = ExprSelect(this, '->', read.name) + return self.maybeAddNullabilityArg( + ipdltype, ExprCall(read, args=[ expr, from_, iterexpr ])) + + def checkedWrite(self, ipdltype, expr, msgvar, sentinelKey, this=None): + assert sentinelKey + + write = StmtExpr(self.write(ipdltype, expr, msgvar, this)) + + sentinel = StmtExpr(ExprCall(ExprSelect(msgvar, '->', 'WriteSentinel'), + args=[ ExprLiteral.Int(hashfunc(sentinelKey)) ])) + block = Block() + block.addstmts([ + write, + Whitespace('// Sentinel = ' + repr(sentinelKey) + '\n', indent=1), + sentinel ]) + return block + + + def visitMessageDecl(self, md): + isctor = md.decl.type.isCtor() + isdtor = md.decl.type.isDtor() + decltype = md.decl.type + sendmethod = None + helpermethod = None + recvlbl, recvcase = None, None + + def addRecvCase(lbl, case): + if decltype.isAsync(): + self.asyncSwitch.addcase(lbl, case) + elif decltype.isSync(): + self.syncSwitch.addcase(lbl, case) + elif decltype.isInterrupt(): + self.interruptSwitch.addcase(lbl, case) + else: assert 0 + + if self.sendsMessage(md): + isasync = decltype.isAsync() + + if isctor: + self.cls.addstmts([ self.genHelperCtor(md), Whitespace.NL ]) + + if isctor and isasync: + sendmethod, (recvlbl, recvcase) = self.genAsyncCtor(md) + elif isctor: + sendmethod = self.genBlockingCtorMethod(md) + elif isdtor and isasync: + sendmethod, (recvlbl, recvcase) = self.genAsyncDtor(md) + elif isdtor: + sendmethod = self.genBlockingDtorMethod(md) + elif isasync: + sendmethod = self.genAsyncSendMethod(md) + else: + sendmethod = self.genBlockingSendMethod(md) + + # XXX figure out what to do here + if isdtor and md.decl.type.constructedType().isToplevel(): + sendmethod = None + + if sendmethod is not None: + self.cls.addstmts([ sendmethod, Whitespace.NL ]) + if recvcase is not None: + addRecvCase(recvlbl, recvcase) + recvlbl, recvcase = None, None + + if self.receivesMessage(md): + if isctor: + recvlbl, recvcase = self.genCtorRecvCase(md) + elif isdtor: + recvlbl, recvcase = self.genDtorRecvCase(md) + else: + recvlbl, recvcase = self.genRecvCase(md) + + # XXX figure out what to do here + if isdtor and md.decl.type.constructedType().isToplevel(): + return + + addRecvCase(recvlbl, recvcase) + + + def genAsyncCtor(self, md): + actor = md.actorDecl() + method = MethodDefn(self.makeSendMethodDecl(md)) + method.addstmts(self.ctorPrologue(md) + [ Whitespace.NL ]) + + msgvar, stmts = self.makeMessage(md, errfnSendCtor) + sendok, sendstmts = self.sendAsync(md, msgvar) + method.addstmts( + stmts + + self.genVerifyMessage(md.decl.type.verify, md.params, + errfnSendCtor, ExprVar('msg__')) + + sendstmts + + self.failCtorIf(md, ExprNot(sendok)) + + [ StmtReturn(actor.var()) ]) + + lbl = CaseLabel(md.pqReplyId()) + case = StmtBlock() + case.addstmt(StmtReturn(_Result.Processed)) + # TODO not really sure what to do with async ctor "replies" yet. + # destroy actor if there was an error? tricky ... + + return method, (lbl, case) + + + def genBlockingCtorMethod(self, md): + actor = md.actorDecl() + method = MethodDefn(self.makeSendMethodDecl(md)) + method.addstmts(self.ctorPrologue(md) + [ Whitespace.NL ]) + + msgvar, stmts = self.makeMessage(md, errfnSendCtor) + + replyvar = self.replyvar + sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar) + method.addstmts( + stmts + + [ Whitespace.NL, + StmtDecl(Decl(Type('Message'), replyvar.name)) ] + + self.genVerifyMessage(md.decl.type.verify, md.params, + errfnSendCtor, ExprVar('msg__')) + + sendstmts + + self.failCtorIf(md, ExprNot(sendok))) + + def errfnCleanupCtor(msg): + return self.failCtorIf(md, ExprLiteral.TRUE) + stmts = self.deserializeReply( + md, ExprAddrOf(replyvar), self.side, errfnCleanupCtor) + method.addstmts(stmts + [ StmtReturn(actor.var()) ]) + + return method + + + def ctorPrologue(self, md, errfn=ExprLiteral.NULL, idexpr=None): + actordecl = md.actorDecl() + actorvar = actordecl.var() + actorproto = actordecl.ipdltype.protocol + actortype = ipdl.type.ActorType(actorproto) + + if idexpr is None: + idexpr = ExprCall(self.protocol.registerMethod(), + args=[ actorvar ]) + else: + idexpr = ExprCall(self.protocol.registerIDMethod(), + args=[ actorvar, idexpr ]) + + return [ + self.failIfNullActor(actorvar, errfn, msg="Error constructing actor %s" % actortype.name() + self.side.capitalize()), + StmtExpr(ExprCall(ExprSelect(actorvar, '->', 'SetId'), args=[idexpr])), + StmtExpr(ExprCall(ExprSelect(actorvar, '->', 'SetManager'), args=[ExprVar.THIS])), + StmtExpr(ExprCall(ExprSelect(actorvar, '->', 'SetIPCChannel'), + args=[self.protocol.callGetChannel()])), + StmtExpr(_callInsertManagedActor( + self.protocol.managedVar(md.decl.type.constructedType(), + self.side), + actorvar)), + StmtExpr(ExprAssn(_actorState(actorvar), + _startState(actorproto, fq=1))) + ] + + def failCtorIf(self, md, cond): + actorvar = md.actorDecl().var() + type = md.decl.type.constructedType() + failif = StmtIf(cond) + + if self.side=='child': + # in the child process this should not fail + failif.addifstmt(_fatalError('constructor for actor failed')) + else: + failif.addifstmts(self.destroyActor(md, actorvar, + why=_DestroyReason.FailedConstructor)) + + failif.addifstmt(StmtReturn(ExprLiteral.NULL)) + return [ failif ] + + def genHelperCtor(self, md): + helperdecl = self.makeSendMethodDecl(md) + helperdecl.params = helperdecl.params[1:] + helper = MethodDefn(helperdecl) + + callctor = self.callAllocActor(md, retsems='out', side=self.side) + helper.addstmt(StmtReturn(ExprCall( + ExprVar(helperdecl.name), args=[ callctor ] + callctor.args))) + return helper + + + def genAsyncDtor(self, md): + actor = md.actorDecl() + actorvar = actor.var() + method = MethodDefn(self.makeDtorMethodDecl(md)) + + method.addstmts(self.dtorPrologue(actorvar)) + + msgvar, stmts = self.makeMessage(md, errfnSendDtor, actorvar) + sendok, sendstmts = self.sendAsync(md, msgvar, actorvar) + method.addstmts( + stmts + + self.genVerifyMessage(md.decl.type.verify, md.params, + errfnSendDtor, ExprVar('msg__')) + + sendstmts + + [ Whitespace.NL ] + + self.dtorEpilogue(md, actor.var()) + + [ StmtReturn(sendok) ]) + + lbl = CaseLabel(md.pqReplyId()) + case = StmtBlock() + case.addstmt(StmtReturn(_Result.Processed)) + # TODO if the dtor is "inherently racy", keep the actor alive + # until the other side acks + + return method, (lbl, case) + + + def genBlockingDtorMethod(self, md): + actor = md.actorDecl() + actorvar = actor.var() + method = MethodDefn(self.makeDtorMethodDecl(md)) + + method.addstmts(self.dtorPrologue(actorvar)) + + msgvar, stmts = self.makeMessage(md, errfnSendDtor, actorvar) + + replyvar = self.replyvar + sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar, actorvar) + method.addstmts( + stmts + + self.genVerifyMessage(md.decl.type.verify, md.params, + errfnSendDtor, ExprVar('msg__')) + + [ Whitespace.NL, + StmtDecl(Decl(Type('Message'), replyvar.name)) ] + + sendstmts) + + destmts = self.deserializeReply( + md, ExprAddrOf(replyvar), self.side, errfnSend, actorvar) + ifsendok = StmtIf(ExprLiteral.FALSE) + ifsendok.addifstmts(destmts) + ifsendok.addifstmts([ Whitespace.NL, + StmtExpr(ExprAssn(sendok, ExprLiteral.FALSE, '&=')) ]) + + method.addstmt(ifsendok) + + if self.protocol.decl.type.hasReentrantDelete: + method.addstmts(self.transition(md, 'in', actor.var(), reply=True)) + + method.addstmts( + self.dtorEpilogue(md, actor.var()) + + [ Whitespace.NL, StmtReturn(sendok) ]) + + return method + + def destroyActor(self, md, actorexpr, why=_DestroyReason.Deletion): + if md.decl.type.isCtor(): + destroyedType = md.decl.type.constructedType() + else: + destroyedType = self.protocol.decl.type + return ([ StmtExpr(self.callActorDestroy(actorexpr, why)), + StmtExpr(self.callDeallocSubtree(md, actorexpr)), + StmtExpr(self.callRemoveActor( + actorexpr, + manager=self.protocol.managerVar(actorexpr), + ipdltype=destroyedType)) + ]) + + def dtorPrologue(self, actorexpr): + return [ self.failIfNullActor(actorexpr), Whitespace.NL ] + + def dtorEpilogue(self, md, actorexpr): + return self.destroyActor(md, actorexpr) + + def genAsyncSendMethod(self, md): + method = MethodDefn(self.makeSendMethodDecl(md)) + msgvar, stmts = self.makeMessage(md, errfnSend) + sendok, sendstmts = self.sendAsync(md, msgvar) + method.addstmts(stmts + +[ Whitespace.NL ] + + self.genVerifyMessage(md.decl.type.verify, md.params, + errfnSend, ExprVar('msg__')) + + sendstmts + +[ StmtReturn(sendok) ]) + return method + + + def genBlockingSendMethod(self, md, fromActor=None): + method = MethodDefn(self.makeSendMethodDecl(md)) + + msgvar, serstmts = self.makeMessage(md, errfnSend, fromActor) + replyvar = self.replyvar + + sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar) + failif = StmtIf(ExprNot(sendok)) + failif.addifstmt(StmtReturn.FALSE) + + desstmts = self.deserializeReply( + md, ExprAddrOf(replyvar), self.side, errfnSend) + + method.addstmts( + serstmts + + self.genVerifyMessage(md.decl.type.verify, md.params, errfnSend, + ExprVar('msg__')) + + [ Whitespace.NL, + StmtDecl(Decl(Type('Message'), replyvar.name)) ] + + sendstmts + + [ failif ] + + desstmts + + [ Whitespace.NL, + StmtReturn.TRUE ]) + + return method + + + def genCtorRecvCase(self, md): + lbl = CaseLabel(md.pqMsgId()) + case = StmtBlock() + actorvar = md.actorDecl().var() + actorhandle = self.handlevar + + stmts = self.deserializeMessage(md, self.side, errfnRecv) + + idvar, saveIdStmts = self.saveActorId(md) + case.addstmts( + stmts + + self.transition(md, 'in') + + [ StmtDecl(Decl(r.bareType(self.side), r.var().name)) + for r in md.returns ] + # alloc the actor, register it under the foreign ID + + [ StmtExpr(ExprAssn( + actorvar, + self.callAllocActor(md, retsems='in', side=self.side))) ] + + self.ctorPrologue(md, errfn=_Result.ValuError, + idexpr=_actorHId(actorhandle)) + + [ Whitespace.NL ] + + saveIdStmts + + self.invokeRecvHandler(md) + + self.makeReply(md, errfnRecv, idvar) + + self.genVerifyMessage(md.decl.type.verify, md.returns, errfnRecv, + self.replyvar) + + [ Whitespace.NL, + StmtReturn(_Result.Processed) ]) + + return lbl, case + + + def genDtorRecvCase(self, md): + lbl = CaseLabel(md.pqMsgId()) + case = StmtBlock() + + stmts = self.deserializeMessage(md, self.side, errfnRecv) + + idvar, saveIdStmts = self.saveActorId(md) + case.addstmts( + stmts + + self.transition(md, 'in') + + [ StmtDecl(Decl(r.bareType(self.side), r.var().name)) + for r in md.returns ] + + self.invokeRecvHandler(md, implicit=0) + + [ Whitespace.NL ] + + saveIdStmts + + self.makeReply(md, errfnRecv, routingId=idvar) + + [ Whitespace.NL ] + + self.genVerifyMessage(md.decl.type.verify, md.returns, errfnRecv, + self.replyvar) + + self.dtorEpilogue(md, md.actorDecl().var()) + + [ Whitespace.NL, + StmtReturn(_Result.Processed) ]) + + return lbl, case + + + def genRecvCase(self, md): + lbl = CaseLabel(md.pqMsgId()) + case = StmtBlock() + + stmts = self.deserializeMessage(md, self.side, errfn=errfnRecv) + + idvar, saveIdStmts = self.saveActorId(md) + case.addstmts( + stmts + + self.transition(md, 'in') + + [ StmtDecl(Decl(r.bareType(self.side), r.var().name)) + for r in md.returns ] + + saveIdStmts + + self.invokeRecvHandler(md) + + [ Whitespace.NL ] + + self.makeReply(md, errfnRecv, routingId=idvar) + + self.genVerifyMessage(md.decl.type.verify, md.returns, errfnRecv, + self.replyvar) + + [ StmtReturn(_Result.Processed) ]) + + return lbl, case + + + # helper methods + + def failIfNullActor(self, actorExpr, retOnNull=ExprLiteral.FALSE, msg=None): + failif = StmtIf(ExprNot(actorExpr)) + if msg: + failif.addifstmt(_printWarningMessage(msg)) + failif.addifstmt(StmtReturn(retOnNull)) + return failif + + def unregisterActor(self): + return [ StmtExpr(ExprCall(self.protocol.unregisterMethod(), + args=[ _actorId() ])), + StmtExpr(ExprCall(ExprVar('SetId'), args=[_FREED_ACTOR_ID])) ] + + def makeMessage(self, md, errfn, fromActor=None): + msgvar = self.msgvar + routingId = self.protocol.routingId(fromActor) + this = None + if md.decl.type.isDtor(): this = md.actorDecl().var() + + stmts = ([ StmtDecl(Decl(Type('IPC::Message', ptr=1), msgvar.name), + init=ExprCall(ExprVar(md.pqMsgCtorFunc()), + args=[ routingId ])) ] + + [ Whitespace.NL ] + + [ self.checkedWrite(p.ipdltype, p.var(), msgvar, sentinelKey=p.name, this=this) + for p in md.params ] + + [ Whitespace.NL ] + + self.setMessageFlags(md, msgvar, reply=0)) + return msgvar, stmts + + + def makeReply(self, md, errfn, routingId): + if routingId is None: + routingId = self.protocol.routingId() + # TODO special cases for async ctor/dtor replies + if not md.decl.type.hasReply(): + return [ ] + + replyvar = self.replyvar + return ( + [ StmtExpr(ExprAssn( + replyvar, ExprCall(ExprVar(md.pqReplyCtorFunc()), args=[ routingId ]))), + Whitespace.NL ] + + [ self.checkedWrite(r.ipdltype, r.var(), replyvar, sentinelKey=r.name) + for r in md.returns ] + + self.setMessageFlags(md, replyvar, reply=1) + + [ self.logMessage(md, replyvar, 'Sending reply ') ]) + + def genVerifyMessage(self, verify, params, errfn, msgsrcVar): + stmts = [ ] + if not verify: + return stmts + if len(params) == 0: + return stmts + + msgvar = ExprVar('msgverify__') + side = self.side + + msgexpr = ExprAddrOf(msgvar) + itervar = ExprVar('msgverifyIter__') + # IPC::Message msgverify__ = Move(*(reply__)); or + # IPC::Message msgverify__ = Move(*(msg__)); + stmts.append(StmtDecl(Decl(Type('IPC::Message', ptr=0), 'msgverify__'), + init=ExprMove(ExprDeref(msgsrcVar)))) + + stmts.extend(( + # PickleIterator msgverifyIter__ = PickleIterator(msgverify__); + [ StmtDecl(Decl(_iterType(ptr=0), itervar.name), + init=ExprCall(ExprVar('PickleIterator'), + args=[ msgvar ])) ] + # declare varCopy for each variable to deserialize. + + [ StmtDecl(Decl(p.bareType(side), p.var().name + 'Copy')) + for p in params ] + + [ Whitespace.NL ] + # checked Read(&(varCopy), &(msgverify__), &(msgverifyIter__)) + + [ self.checkedRead(p.ipdltype, + ExprAddrOf(ExprVar(p.var().name + 'Copy')), + msgexpr, ExprAddrOf(itervar), + errfn, p.bareType(side).name, + p.name) + for p in params ] + + [ self.endRead(msgvar, itervar) ] + # Move the message back to its source before sending. + + [ StmtExpr(ExprAssn(ExprDeref(msgsrcVar), ExprMove(msgvar))) ] + )) + + return stmts + + def setMessageFlags(self, md, var, reply): + stmts = [ ] + + if md.decl.type.isSync(): + stmts.append(StmtExpr(ExprCall( + ExprSelect(var, '->', 'set_sync')))) + elif md.decl.type.isInterrupt(): + stmts.append(StmtExpr(ExprCall( + ExprSelect(var, '->', 'set_interrupt')))) + + if reply: + stmts.append(StmtExpr(ExprCall( + ExprSelect(var, '->', 'set_reply')))) + + return stmts + [ Whitespace.NL ] + + + def deserializeMessage(self, md, side, errfn): + msgvar = self.msgvar + itervar = self.itervar + msgexpr = ExprAddrOf(msgvar) + isctor = md.decl.type.isCtor() + stmts = ([ + self.logMessage(md, msgexpr, 'Received ', + receiving=True), + self.profilerLabel(md), + Whitespace.NL + ]) + + if 0 == len(md.params): + return stmts + + start, decls, reads = 0, [], [] + if isctor: + # return the raw actor handle so that its ID can be used + # to construct the "real" actor + handlevar = self.handlevar + handletype = Type('ActorHandle') + decls = [ StmtDecl(Decl(handletype, handlevar.name)) ] + reads = [ self.checkedRead(None, ExprAddrOf(handlevar), msgexpr, + ExprAddrOf(self.itervar), + errfn, "'%s'" % handletype.name, + sentinelKey='actor') ] + start = 1 + + stmts.extend(( + [ StmtDecl(Decl(_iterType(ptr=0), self.itervar.name), + init=ExprCall(ExprVar('PickleIterator'), + args=[ msgvar ])) ] + + decls + [ StmtDecl(Decl(p.bareType(side), p.var().name)) + for p in md.params ] + + [ Whitespace.NL ] + + reads + [ self.checkedRead(p.ipdltype, ExprAddrOf(p.var()), + msgexpr, ExprAddrOf(itervar), + errfn, "'%s'" % p.bareType(side).name, + sentinelKey=p.name) + for p in md.params[start:] ] + + [ self.endRead(msgvar, itervar) ])) + + return stmts + + + def deserializeReply(self, md, replyexpr, side, errfn, actor=None): + stmts = [ Whitespace.NL, + self.logMessage(md, replyexpr, + 'Received reply ', actor, receiving=True) ] + if 0 == len(md.returns): + return stmts + + itervar = self.itervar + stmts.extend( + [ Whitespace.NL, + StmtDecl(Decl(_iterType(ptr=0), itervar.name), + init=ExprCall(ExprVar('PickleIterator'), + args=[ self.replyvar ])) ] + + [ self.checkedRead(r.ipdltype, r.var(), + ExprAddrOf(self.replyvar), + ExprAddrOf(self.itervar), + errfn, "'%s'" % r.bareType(side).name, + sentinelKey=r.name) + for r in md.returns ] + + [ self.endRead(self.replyvar, itervar) ]) + + return stmts + + def sendAsync(self, md, msgexpr, actor=None): + sendok = ExprVar('sendok__') + return ( + sendok, + ([ Whitespace.NL, + self.logMessage(md, msgexpr, 'Sending ', actor), + self.profilerLabel(md) ] + + self.transition(md, 'out', actor) + + [ Whitespace.NL, + StmtDecl(Decl(Type.BOOL, sendok.name), + init=ExprCall( + ExprSelect(self.protocol.callGetChannel(actor), + '->', 'Send'), + args=[ msgexpr ])) + ]) + ) + + def sendBlocking(self, md, msgexpr, replyexpr, actor=None): + sendok = ExprVar('sendok__') + return ( + sendok, + ([ Whitespace.NL, + self.logMessage(md, msgexpr, 'Sending ', actor), + self.profilerLabel(md) ] + + self.transition(md, 'out', actor) + + [ Whitespace.NL, + StmtDecl( + Decl(Type.BOOL, sendok.name), + init=ExprCall(ExprSelect(self.protocol.callGetChannel(actor), + '->', + _sendPrefix(md.decl.type)), + args=[ msgexpr, ExprAddrOf(replyexpr) ])) + ]) + ) + + def callAllocActor(self, md, retsems, side): + return ExprCall( + _allocMethod(md.decl.type.constructedType(), side), + args=md.makeCxxArgs(retsems=retsems, retcallsems='out', + implicit=0)) + + def callActorDestroy(self, actorexpr, why=_DestroyReason.Deletion): + return ExprCall(ExprSelect(actorexpr, '->', 'DestroySubtree'), + args=[ why ]) + + def callRemoveActor(self, actorexpr, manager=None, ipdltype=None): + if ipdltype is None: ipdltype = self.protocol.decl.type + + if not ipdltype.isManaged(): + return Whitespace('// unmanaged protocol') + + removefunc = self.protocol.removeManageeMethod() + if manager is not None: + removefunc = ExprSelect(manager, '->', removefunc.name) + + return ExprCall(removefunc, + args=[ _protocolId(ipdltype), + actorexpr ]) + + def callDeallocSubtree(self, md, actorexpr): + return ExprCall(ExprSelect(actorexpr, '->', 'DeallocSubtree')) + + def invokeRecvHandler(self, md, implicit=1): + failif = StmtIf(ExprNot( + ExprCall(md.recvMethod(), + args=md.makeCxxArgs(paramsems='move', retsems='in', + retcallsems='out', + implicit=implicit)))) + failif.addifstmts([ + _protocolErrorBreakpoint('Handler returned error code!'), + StmtReturn(_Result.ProcessingError) + ]) + return [ failif ] + + def makeDtorMethodDecl(self, md): + decl = self.makeSendMethodDecl(md) + decl.static = 1 + return decl + + def makeSendMethodDecl(self, md): + implicit = md.decl.type.hasImplicitActorParam() + decl = MethodDecl( + md.sendMethod().name, + params=md.makeCxxParams(paramsems='in', returnsems='out', + side=self.side, implicit=implicit), + warn_unused=(self.side == 'parent'), + ret=Type.BOOL) + if md.decl.type.isCtor(): + decl.ret = md.actorDecl().bareType(self.side) + return decl + + def logMessage(self, md, msgptr, pfx, actor=None, receiving=False): + actorname = _actorName(self.protocol.name, self.side) + + return _ifLogging(ExprLiteral.String(actorname), + [ StmtExpr(ExprCall( + ExprVar('mozilla::ipc::LogMessageForProtocol'), + args=[ ExprLiteral.String(actorname), + self.protocol.callOtherPid(actor), + ExprLiteral.String(pfx), + ExprCall(ExprSelect(msgptr, '->', 'type')), + ExprVar('mozilla::ipc::MessageDirection::eReceiving' + if receiving + else 'mozilla::ipc::MessageDirection::eSending') ])) ]) + + def profilerLabel(self, md): + return StmtExpr(ExprCall(ExprVar('PROFILER_LABEL'), + [ ExprLiteral.String(self.protocol.name), + ExprLiteral.String(md.prettyMsgName()), + ExprVar('js::ProfileEntry::Category::OTHER') ])) + + def saveActorId(self, md): + idvar = ExprVar('id__') + if md.decl.type.hasReply(): + # only save the ID if we're actually going to use it, to + # avoid unused-variable warnings + saveIdStmts = [ StmtDecl(Decl(_actorIdType(), idvar.name), + self.protocol.routingId()) ] + else: + saveIdStmts = [ ] + return idvar, saveIdStmts + + def transition(self, md, direction, actor=None, reply=False): + if actor is not None: stateexpr = _actorState(actor) + else: stateexpr = self.protocol.stateVar() + + if (self.side is 'parent' and direction is 'out' + or self.side is 'child' and direction is 'in'): + action = ExprVar('Trigger::Send') + elif (self.side is 'parent' and direction is 'in' + or self.side is 'child' and direction is 'out'): + action = ExprVar('Trigger::Recv') + else: assert 0 and 'unknown combo %s/%s'% (self.side, direction) + + msgid = md.pqMsgId() if not reply else md.pqReplyId() + ifbad = StmtIf(ExprNot( + ExprCall( + ExprVar(self.protocol.name +'::Transition'), + args=[ ExprCall(ExprVar('Trigger'), + args=[ action, ExprVar(msgid) ]), + ExprAddrOf(stateexpr) ]))) + ifbad.addifstmts(_badTransition()) + return [ ifbad ] + + def checkedRead(self, ipdltype, expr, msgexpr, iterexpr, errfn, paramtype, sentinelKey, sentinel=True): + ifbad = StmtIf(ExprNot(self.read(ipdltype, expr, msgexpr, iterexpr))) + if isinstance(paramtype, list): + errorcall = errfn(*paramtype) + else: + errorcall = errfn('Error deserializing ' + paramtype) + ifbad.addifstmts(errorcall) + + block = Block() + block.addstmt(ifbad) + + if sentinel: + assert sentinelKey + + block.addstmt(Whitespace('// Sentinel = ' + repr(sentinelKey) + '\n', indent=1)) + read = ExprCall(ExprSelect(msgexpr, '->', 'ReadSentinel'), + args=[ iterexpr, ExprLiteral.Int(hashfunc(sentinelKey)) ]) + ifsentinel = StmtIf(ExprNot(read)) + ifsentinel.addifstmts(errorcall) + block.addstmt(ifsentinel) + + return block + + def endRead(self, msgexpr, iterexpr): + return StmtExpr(ExprCall(ExprSelect(msgexpr, '.', 'EndRead'), + args=[ iterexpr ])) + +class _GenerateProtocolParentCode(_GenerateProtocolActorCode): + def __init__(self): + _GenerateProtocolActorCode.__init__(self, 'parent') + + def sendsMessage(self, md): + return not md.decl.type.isIn() + + def receivesMessage(self, md): + return md.decl.type.isInout() or md.decl.type.isIn() + +class _GenerateProtocolChildCode(_GenerateProtocolActorCode): + def __init__(self): + _GenerateProtocolActorCode.__init__(self, 'child') + + def sendsMessage(self, md): + return not md.decl.type.isOut() + + def receivesMessage(self, md): + return md.decl.type.isInout() or md.decl.type.isOut() + + +##----------------------------------------------------------------------------- +## Utility passes +## + +def _splitClassDeclDefn(cls): + """Destructively split |cls| methods into declarations and +definitions (if |not methodDecl.force_inline|). Return classDecl, +methodDefns.""" + defns = Block() + + for i, stmt in enumerate(cls.stmts): + if isinstance(stmt, MethodDefn) and not stmt.decl.force_inline: + decl, defn = _splitMethodDefn(stmt, cls.name) + cls.stmts[i] = StmtDecl(decl) + defns.addstmts([ defn, Whitespace.NL ]) + + return cls, defns + +def _splitMethodDefn(md, clsname): + saveddecl = deepcopy(md.decl) + md.decl.name = (clsname +'::'+ md.decl.name) + md.decl.virtual = 0 + md.decl.static = 0 + md.decl.warn_unused = 0 + md.decl.never_inline = 0 + md.decl.only_for_definition = True + for param in md.decl.params: + if isinstance(param, Param): + param.default = None + return saveddecl, md + + +def _splitFuncDeclDefn(fun): + assert not fun.decl.inline + return StmtDecl(fun.decl), fun + + +# XXX this is tantalizingly similar to _splitClassDeclDefn, but just +# different enough that I don't see the need to define +# _GenerateSkeleton in terms of that +class _GenerateSkeletonImpl(Visitor): + def __init__(self, name, namespaces): + self.name = name + self.cls = None + self.namespaces = namespaces + self.methodimpls = Block() + + def fromclass(self, cls): + cls.accept(self) + + nsclass = _putInNamespaces(self.cls, self.namespaces) + nsmethodimpls = _putInNamespaces(self.methodimpls, self.namespaces) + + return [ + Whitespace(''' +//----------------------------------------------------------------------------- +// Skeleton implementation of abstract actor class + +'''), + Whitespace('// Header file contents\n'), + nsclass, + Whitespace.NL, + Whitespace('\n// C++ file contents\n'), + nsmethodimpls + ] + + + def visitClass(self, cls): + self.cls = Class(self.name, inherits=[ Inherit(Type(cls.name)) ]) + Visitor.visitClass(self, cls) + + def visitMethodDecl(self, md): + if not md.pure: + return + decl = deepcopy(md) + decl.pure = 0 + impl = MethodDefn(MethodDecl(self.implname(md.name), + params=md.params, + ret=md.ret)) + if md.ret.ptr: + impl.addstmt(StmtReturn(ExprLiteral.ZERO)) + elif md.ret == Type.BOOL: + impl.addstmt(StmtReturn(ExprVar('false'))) + + self.cls.addstmts([ StmtDecl(decl), Whitespace.NL ]) + self.addmethodimpl(impl) + + def visitConstructorDecl(self, cd): + self.cls.addstmt(StmtDecl(ConstructorDecl(self.name))) + ctor = ConstructorDefn(ConstructorDecl(self.implname(self.name))) + ctor.addstmt(StmtExpr(ExprCall(ExprVar( 'MOZ_COUNT_CTOR'), + [ ExprVar(self.name) ]))) + self.addmethodimpl(ctor) + + def visitDestructorDecl(self, dd): + self.cls.addstmt( + StmtDecl(DestructorDecl(self.name, virtual=1))) + # FIXME/cjones: hack! + dtor = DestructorDefn(ConstructorDecl(self.implname('~' +self.name))) + dtor.addstmt(StmtExpr(ExprCall(ExprVar( 'MOZ_COUNT_DTOR'), + [ ExprVar(self.name) ]))) + self.addmethodimpl(dtor) + + def addmethodimpl(self, impl): + self.methodimpls.addstmts([ impl, Whitespace.NL ]) + + def implname(self, method): + return self.name +'::'+ method diff --git a/ipc/ipdl/ipdl/parser.py b/ipc/ipdl/ipdl/parser.py new file mode 100644 index 000000000..38c46dc73 --- /dev/null +++ b/ipc/ipdl/ipdl/parser.py @@ -0,0 +1,807 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +import os, sys +from ply import lex, yacc + +from ipdl.ast import * + +def _getcallerpath(): + '''Return the absolute path of the file containing the code that +**CALLED** this function.''' + return os.path.abspath(sys._getframe(1).f_code.co_filename) + +##----------------------------------------------------------------------------- + +class ParseError(Exception): + def __init__(self, loc, fmt, *args): + self.loc = loc + self.error = ('%s%s: error: %s'% ( + Parser.includeStackString(), loc, fmt)) % args + def __str__(self): + return self.error + +def _safeLinenoValue(t): + lineno, value = 0, '???' + if hasattr(t, 'lineno'): lineno = t.lineno + if hasattr(t, 'value'): value = t.value + return lineno, value + +def _error(loc, fmt, *args): + raise ParseError(loc, fmt, *args) + +class Parser: + # when we reach an |include [protocol] foo;| statement, we need to + # save the current parser state and create a new one. this "stack" is + # where that state is saved + # + # there is one Parser per file + current = None + parseStack = [ ] + parsed = { } + + def __init__(self, type, name, debug=0): + assert type and name + self.type = type + self.debug = debug + self.filename = None + self.includedirs = None + self.loc = None # not always up to date + self.lexer = None + self.parser = None + self.tu = TranslationUnit(type, name) + self.direction = None + self.errout = None + + def parse(self, input, filename, includedirs, errout): + assert os.path.isabs(filename) + + if filename in Parser.parsed: + return Parser.parsed[filename].tu + + self.lexer = lex.lex(debug=self.debug, + optimize=not self.debug, + lextab="ipdl_lextab") + self.parser = yacc.yacc(debug=self.debug, + optimize=not self.debug, + tabmodule="ipdl_yacctab") + self.filename = filename + self.includedirs = includedirs + self.tu.filename = filename + self.errout = errout + + Parser.parsed[filename] = self + Parser.parseStack.append(Parser.current) + Parser.current = self + + try: + ast = self.parser.parse(input=input, lexer=self.lexer, + debug=self.debug) + except ParseError, p: + print >>errout, p + return None + + Parser.current = Parser.parseStack.pop() + return ast + + def resolveIncludePath(self, filepath): + '''Return the absolute path from which the possibly partial +|filepath| should be read, or |None| if |filepath| cannot be located.''' + for incdir in self.includedirs +[ '' ]: + realpath = os.path.join(incdir, filepath) + if os.path.isfile(realpath): + return os.path.abspath(realpath) + return None + + # returns a GCC-style string representation of the include stack. + # e.g., + # in file included from 'foo.ipdl', line 120: + # in file included from 'bar.ipd', line 12: + # which can be printed above a proper error message or warning + @staticmethod + def includeStackString(): + s = '' + for parse in Parser.parseStack[1:]: + s += " in file included from `%s', line %d:\n"% ( + parse.loc.filename, parse.loc.lineno) + return s + +def locFromTok(p, num): + return Loc(Parser.current.filename, p.lineno(num)) + + +##----------------------------------------------------------------------------- + +reserved = set(( + 'answer', + 'as', + 'async', + 'both', + 'bridges', + 'call', + 'child', + 'class', + 'compress', + 'compressall', + '__delete__', + 'delete', # reserve 'delete' to prevent its use + 'from', + 'goto', + 'include', + 'intr', + 'manager', + 'manages', + 'namespace', + 'nested', + 'nullable', + 'opens', + 'or', + 'parent', + 'prio', + 'protocol', + 'recv', + 'returns', + 'send', + 'spawns', + 'start', + 'state', + 'struct', + 'sync', + 'union', + 'upto', + 'using', + 'verify')) +tokens = [ + 'COLONCOLON', 'ID', 'STRING', +] + [ r.upper() for r in reserved ] + +t_COLONCOLON = '::' + +literals = '(){}[]<>;:,~' +t_ignore = ' \f\t\v' + +def t_linecomment(t): + r'//[^\n]*' + +def t_multilinecomment(t): + r'/\*(\n|.)*?\*/' + t.lexer.lineno += t.value.count('\n') + +def t_NL(t): + r'(?:\r\n|\n|\n)+' + t.lexer.lineno += len(t.value) + +def t_ID(t): + r'[a-zA-Z_][a-zA-Z0-9_]*' + if t.value in reserved: + t.type = t.value.upper() + return t + +def t_STRING(t): + r'"[^"\n]*"' + t.value = t.value[1:-1] + return t + +def t_error(t): + _error(Loc(Parser.current.filename, t.lineno), + 'lexically invalid characters `%s', t.value) + +##----------------------------------------------------------------------------- + +def p_TranslationUnit(p): + """TranslationUnit : Preamble NamespacedStuff""" + tu = Parser.current.tu + tu.loc = Loc(tu.filename) + for stmt in p[1]: + if isinstance(stmt, CxxInclude): + tu.addCxxInclude(stmt) + elif isinstance(stmt, Include): + tu.addInclude(stmt) + elif isinstance(stmt, UsingStmt): + tu.addUsingStmt(stmt) + else: + assert 0 + + for thing in p[2]: + if isinstance(thing, StructDecl): + tu.addStructDecl(thing) + elif isinstance(thing, UnionDecl): + tu.addUnionDecl(thing) + elif isinstance(thing, Protocol): + if tu.protocol is not None: + _error(thing.loc, "only one protocol definition per file") + tu.protocol = thing + else: + assert(0) + + # The "canonical" namespace of the tu, what it's considered to be + # in for the purposes of C++: |#include "foo/bar/TU.h"| + if tu.protocol: + assert tu.filetype == 'protocol' + tu.namespaces = tu.protocol.namespaces + tu.name = tu.protocol.name + else: + assert tu.filetype == 'header' + # There's not really a canonical "thing" in headers. So + # somewhat arbitrarily use the namespace of the last + # interesting thing that was declared. + for thing in reversed(tu.structsAndUnions): + tu.namespaces = thing.namespaces + break + + p[0] = tu + +##-------------------- +## Preamble +def p_Preamble(p): + """Preamble : Preamble PreambleStmt ';' + |""" + if 1 == len(p): + p[0] = [ ] + else: + p[1].append(p[2]) + p[0] = p[1] + +def p_PreambleStmt(p): + """PreambleStmt : CxxIncludeStmt + | IncludeStmt + | UsingStmt""" + p[0] = p[1] + +def p_CxxIncludeStmt(p): + """CxxIncludeStmt : INCLUDE STRING""" + p[0] = CxxInclude(locFromTok(p, 1), p[2]) + +def p_IncludeStmt(p): + """IncludeStmt : INCLUDE PROTOCOL ID + | INCLUDE ID""" + loc = locFromTok(p, 1) + + Parser.current.loc = loc + if 4 == len(p): + id = p[3] + type = 'protocol' + else: + id = p[2] + type = 'header' + inc = Include(loc, type, id) + + path = Parser.current.resolveIncludePath(inc.file) + if path is None: + raise ParseError(loc, "can't locate include file `%s'"% ( + inc.file)) + + inc.tu = Parser(type, id).parse(open(path).read(), path, Parser.current.includedirs, Parser.current.errout) + p[0] = inc + +def p_UsingStmt(p): + """UsingStmt : USING CxxType FROM STRING + | USING CLASS CxxType FROM STRING + | USING STRUCT CxxType FROM STRING""" + if 6 == len(p): + header = p[5] + elif 5 == len(p): + header = p[4] + else: + header = None + if 6 == len(p): + kind = p[2] + else: + kind = None + if 6 == len(p): + cxxtype = p[3] + else: + cxxtype = p[2] + p[0] = UsingStmt(locFromTok(p, 1), cxxtype, header, kind) + +##-------------------- +## Namespaced stuff +def p_NamespacedStuff(p): + """NamespacedStuff : NamespacedStuff NamespaceThing + | NamespaceThing""" + if 2 == len(p): + p[0] = p[1] + else: + p[1].extend(p[2]) + p[0] = p[1] + +def p_NamespaceThing(p): + """NamespaceThing : NAMESPACE ID '{' NamespacedStuff '}' + | StructDecl + | UnionDecl + | ProtocolDefn""" + if 2 == len(p): + p[0] = [ p[1] ] + else: + for thing in p[4]: + thing.addOuterNamespace(Namespace(locFromTok(p, 1), p[2])) + p[0] = p[4] + +def p_StructDecl(p): + """StructDecl : STRUCT ID '{' StructFields '}' ';' + | STRUCT ID '{' '}' ';'""" + if 7 == len(p): + p[0] = StructDecl(locFromTok(p, 1), p[2], p[4]) + else: + p[0] = StructDecl(locFromTok(p, 1), p[2], [ ]) + +def p_StructFields(p): + """StructFields : StructFields StructField ';' + | StructField ';'""" + if 3 == len(p): + p[0] = [ p[1] ] + else: + p[1].append(p[2]) + p[0] = p[1] + +def p_StructField(p): + """StructField : Type ID""" + p[0] = StructField(locFromTok(p, 1), p[1], p[2]) + +def p_UnionDecl(p): + """UnionDecl : UNION ID '{' ComponentTypes '}' ';'""" + p[0] = UnionDecl(locFromTok(p, 1), p[2], p[4]) + +def p_ComponentTypes(p): + """ComponentTypes : ComponentTypes Type ';' + | Type ';'""" + if 3 == len(p): + p[0] = [ p[1] ] + else: + p[1].append(p[2]) + p[0] = p[1] + +def p_ProtocolDefn(p): + """ProtocolDefn : OptionalProtocolSendSemanticsQual PROTOCOL ID '{' ProtocolBody '}' ';'""" + protocol = p[5] + protocol.loc = locFromTok(p, 2) + protocol.name = p[3] + protocol.nestedRange = p[1][0] + protocol.sendSemantics = p[1][1] + p[0] = protocol + + if Parser.current.type == 'header': + _error(protocol.loc, 'can\'t define a protocol in a header. Do it in a protocol spec instead.') + + +def p_ProtocolBody(p): + """ProtocolBody : SpawnsStmtsOpt""" + p[0] = p[1] + +##-------------------- +## spawns/bridges/opens stmts + +def p_SpawnsStmtsOpt(p): + """SpawnsStmtsOpt : SpawnsStmt SpawnsStmtsOpt + | BridgesStmtsOpt""" + if 2 == len(p): + p[0] = p[1] + else: + p[2].spawnsStmts.insert(0, p[1]) + p[0] = p[2] + +def p_SpawnsStmt(p): + """SpawnsStmt : PARENT SPAWNS ID AsOpt ';' + | CHILD SPAWNS ID AsOpt ';'""" + p[0] = SpawnsStmt(locFromTok(p, 1), p[1], p[3], p[4]) + +def p_AsOpt(p): + """AsOpt : AS PARENT + | AS CHILD + | """ + if 3 == len(p): + p[0] = p[2] + else: + p[0] = 'child' + +def p_BridgesStmtsOpt(p): + """BridgesStmtsOpt : BridgesStmt BridgesStmtsOpt + | OpensStmtsOpt""" + if 2 == len(p): + p[0] = p[1] + else: + p[2].bridgesStmts.insert(0, p[1]) + p[0] = p[2] + +def p_BridgesStmt(p): + """BridgesStmt : BRIDGES ID ',' ID ';'""" + p[0] = BridgesStmt(locFromTok(p, 1), p[2], p[4]) + +def p_OpensStmtsOpt(p): + """OpensStmtsOpt : OpensStmt OpensStmtsOpt + | ManagersStmtOpt""" + if 2 == len(p): + p[0] = p[1] + else: + p[2].opensStmts.insert(0, p[1]) + p[0] = p[2] + +def p_OpensStmt(p): + """OpensStmt : PARENT OPENS ID ';' + | CHILD OPENS ID ';'""" + p[0] = OpensStmt(locFromTok(p, 1), p[1], p[3]) + +##-------------------- +## manager/manages stmts + +def p_ManagersStmtOpt(p): + """ManagersStmtOpt : ManagersStmt ManagesStmtsOpt + | ManagesStmtsOpt""" + if 2 == len(p): + p[0] = p[1] + else: + p[2].managers = p[1] + p[0] = p[2] + +def p_ManagersStmt(p): + """ManagersStmt : MANAGER ManagerList ';'""" + if 1 == len(p): + p[0] = [ ] + else: + p[0] = p[2] + +def p_ManagerList(p): + """ManagerList : ID + | ManagerList OR ID""" + if 2 == len(p): + p[0] = [ Manager(locFromTok(p, 1), p[1]) ] + else: + p[1].append(Manager(locFromTok(p, 3), p[3])) + p[0] = p[1] + +def p_ManagesStmtsOpt(p): + """ManagesStmtsOpt : ManagesStmt ManagesStmtsOpt + | MessageDeclsOpt""" + if 2 == len(p): + p[0] = p[1] + else: + p[2].managesStmts.insert(0, p[1]) + p[0] = p[2] + +def p_ManagesStmt(p): + """ManagesStmt : MANAGES ID ';'""" + p[0] = ManagesStmt(locFromTok(p, 1), p[2]) + + +##-------------------- +## Message decls + +def p_MessageDeclsOpt(p): + """MessageDeclsOpt : MessageDeclThing MessageDeclsOpt + | TransitionStmtsOpt""" + if 2 == len(p): + p[0] = p[1] + else: + p[2].messageDecls.insert(0, p[1]) + p[0] = p[2] + +def p_MessageDeclThing(p): + """MessageDeclThing : MessageDirectionLabel ':' MessageDecl ';' + | MessageDecl ';'""" + if 3 == len(p): + p[0] = p[1] + else: + p[0] = p[3] + +def p_MessageDirectionLabel(p): + """MessageDirectionLabel : PARENT + | CHILD + | BOTH""" + if p[1] == 'parent': + Parser.current.direction = IN + elif p[1] == 'child': + Parser.current.direction = OUT + elif p[1] == 'both': + Parser.current.direction = INOUT + else: + assert 0 + +def p_MessageDecl(p): + """MessageDecl : SendSemanticsQual MessageBody""" + msg = p[2] + msg.nested = p[1][0] + msg.prio = p[1][1] + msg.sendSemantics = p[1][2] + + if Parser.current.direction is None: + _error(msg.loc, 'missing message direction') + msg.direction = Parser.current.direction + + p[0] = msg + +def p_MessageBody(p): + """MessageBody : MessageId MessageInParams MessageOutParams OptionalMessageModifiers""" + # FIXME/cjones: need better loc info: use one of the quals + loc, name = p[1] + msg = MessageDecl(loc) + msg.name = name + msg.addInParams(p[2]) + msg.addOutParams(p[3]) + msg.addModifiers(p[4]) + + p[0] = msg + +def p_MessageId(p): + """MessageId : ID + | __DELETE__ + | DELETE + | '~' ID""" + loc = locFromTok(p, 1) + if 3 == len(p): + _error(loc, "sorry, `%s()' destructor syntax is a relic from a bygone era. Declare `__delete__()' in the `%s' protocol instead", p[1]+p[2], p[2]) + elif 'delete' == p[1]: + _error(loc, "`delete' is a reserved identifier") + p[0] = [ loc, p[1] ] + +def p_MessageInParams(p): + """MessageInParams : '(' ParamList ')'""" + p[0] = p[2] + +def p_MessageOutParams(p): + """MessageOutParams : RETURNS '(' ParamList ')' + | """ + if 1 == len(p): + p[0] = [ ] + else: + p[0] = p[3] + +def p_OptionalMessageModifiers(p): + """OptionalMessageModifiers : OptionalMessageModifiers MessageModifier + | MessageModifier + | """ + if 1 == len(p): + p[0] = [ ] + elif 2 == len(p): + p[0] = [ p[1] ] + else: + p[1].append(p[2]) + p[0] = p[1] + +def p_MessageModifier(p): + """ MessageModifier : MessageVerify + | MessageCompress """ + p[0] = p[1] + +def p_MessageVerify(p): + """MessageVerify : VERIFY""" + p[0] = p[1] + +def p_MessageCompress(p): + """MessageCompress : COMPRESS + | COMPRESSALL""" + p[0] = p[1] + +##-------------------- +## State machine + +def p_TransitionStmtsOpt(p): + """TransitionStmtsOpt : TransitionStmt TransitionStmtsOpt + |""" + if 1 == len(p): + # we fill in |loc| in the Protocol rule + p[0] = Protocol(None) + else: + p[2].transitionStmts.insert(0, p[1]) + p[0] = p[2] + +def p_TransitionStmt(p): + """TransitionStmt : OptionalStart STATE State ':' Transitions""" + p[3].start = p[1] + p[0] = TransitionStmt(locFromTok(p, 2), p[3], p[5]) + +def p_OptionalStart(p): + """OptionalStart : START + | """ + p[0] = (len(p) == 2) # True iff 'start' specified + +def p_Transitions(p): + """Transitions : Transitions Transition + | Transition""" + if 3 == len(p): + p[1].append(p[2]) + p[0] = p[1] + else: + p[0] = [ p[1] ] + +def p_Transition(p): + """Transition : Trigger ID GOTO StateList ';' + | Trigger __DELETE__ ';' + | Trigger DELETE ';'""" + if 'delete' == p[2]: + _error(locFromTok(p, 1), "`delete' is a reserved identifier") + + loc, trigger = p[1] + if 6 == len(p): + nextstates = p[4] + else: + nextstates = [ State.DEAD ] + p[0] = Transition(loc, trigger, p[2], nextstates) + +def p_Trigger(p): + """Trigger : SEND + | RECV + | CALL + | ANSWER""" + p[0] = [ locFromTok(p, 1), Transition.nameToTrigger(p[1]) ] + +def p_StateList(p): + """StateList : StateList OR State + | State""" + if 2 == len(p): + p[0] = [ p[1] ] + else: + p[1].append(p[3]) + p[0] = p[1] + +def p_State(p): + """State : ID""" + p[0] = State(locFromTok(p, 1), p[1]) + +##-------------------- +## Minor stuff +def p_Nested(p): + """Nested : ID""" + kinds = {'not': 1, + 'inside_sync': 2, + 'inside_cpow': 3} + if p[1] not in kinds: + _error(locFromTok(p, 1), "Expected not, inside_sync, or inside_cpow for nested()") + + p[0] = { 'nested': kinds[p[1]] } + +def p_Priority(p): + """Priority : ID""" + kinds = {'normal': 1, + 'high': 2} + if p[1] not in kinds: + _error(locFromTok(p, 1), "Expected normal or high for prio()") + + p[0] = { 'prio': kinds[p[1]] } + +def p_SendQualifier(p): + """SendQualifier : NESTED '(' Nested ')' + | PRIO '(' Priority ')'""" + p[0] = p[3] + +def p_SendQualifierList(p): + """SendQualifierList : SendQualifier SendQualifierList + | """ + if len(p) > 1: + p[0] = p[1] + p[0].update(p[2]) + else: + p[0] = {} + +def p_SendSemanticsQual(p): + """SendSemanticsQual : SendQualifierList ASYNC + | SendQualifierList SYNC + | INTR""" + quals = {} + if len(p) == 3: + quals = p[1] + mtype = p[2] + else: + mtype = 'intr' + + if mtype == 'async': mtype = ASYNC + elif mtype == 'sync': mtype = SYNC + elif mtype == 'intr': mtype = INTR + else: assert 0 + + p[0] = [ quals.get('nested', NOT_NESTED), quals.get('prio', NORMAL_PRIORITY), mtype ] + +def p_OptionalProtocolSendSemanticsQual(p): + """OptionalProtocolSendSemanticsQual : ProtocolSendSemanticsQual + | """ + if 2 == len(p): p[0] = p[1] + else: p[0] = [ (NOT_NESTED, NOT_NESTED), ASYNC ] + +def p_ProtocolSendSemanticsQual(p): + """ProtocolSendSemanticsQual : ASYNC + | SYNC + | NESTED '(' UPTO Nested ')' ASYNC + | NESTED '(' UPTO Nested ')' SYNC + | INTR""" + if p[1] == 'nested': + mtype = p[6] + nested = (NOT_NESTED, p[4]) + else: + mtype = p[1] + nested = (NOT_NESTED, NOT_NESTED) + + if mtype == 'async': mtype = ASYNC + elif mtype == 'sync': mtype = SYNC + elif mtype == 'intr': mtype = INTR + else: assert 0 + + p[0] = [ nested, mtype ] + +def p_ParamList(p): + """ParamList : ParamList ',' Param + | Param + | """ + if 1 == len(p): + p[0] = [ ] + elif 2 == len(p): + p[0] = [ p[1] ] + else: + p[1].append(p[3]) + p[0] = p[1] + +def p_Param(p): + """Param : Type ID""" + p[0] = Param(locFromTok(p, 1), p[1], p[2]) + +def p_Type(p): + """Type : MaybeNullable BasicType""" + # only actor types are nullable; we check this in the type checker + p[2].nullable = p[1] + p[0] = p[2] + +def p_BasicType(p): + """BasicType : ScalarType + | ScalarType '[' ']'""" + if 4 == len(p): + p[1].array = 1 + p[0] = p[1] + +def p_ScalarType(p): + """ScalarType : ActorType + | CxxID""" # ID == CxxType; we forbid qnames here, + # in favor of the |using| declaration + if isinstance(p[1], TypeSpec): + p[0] = p[1] + else: + loc, id = p[1] + p[0] = TypeSpec(loc, QualifiedId(loc, id)) + +def p_ActorType(p): + """ActorType : ID ':' State""" + loc = locFromTok(p, 1) + p[0] = TypeSpec(loc, QualifiedId(loc, p[1]), state=p[3]) + +def p_MaybeNullable(p): + """MaybeNullable : NULLABLE + | """ + p[0] = (2 == len(p)) + +##-------------------- +## C++ stuff +def p_CxxType(p): + """CxxType : QualifiedID + | CxxID""" + if isinstance(p[1], QualifiedId): + p[0] = TypeSpec(p[1].loc, p[1]) + else: + loc, id = p[1] + p[0] = TypeSpec(loc, QualifiedId(loc, id)) + +def p_QualifiedID(p): + """QualifiedID : QualifiedID COLONCOLON CxxID + | CxxID COLONCOLON CxxID""" + if isinstance(p[1], QualifiedId): + loc, id = p[3] + p[1].qualify(id) + p[0] = p[1] + else: + loc1, id1 = p[1] + _, id2 = p[3] + p[0] = QualifiedId(loc1, id2, [ id1 ]) + +def p_CxxID(p): + """CxxID : ID + | CxxTemplateInst""" + if isinstance(p[1], tuple): + p[0] = p[1] + else: + p[0] = (locFromTok(p, 1), str(p[1])) + +def p_CxxTemplateInst(p): + """CxxTemplateInst : ID '<' ID '>'""" + p[0] = (locFromTok(p, 1), str(p[1]) +'<'+ str(p[3]) +'>') + +def p_error(t): + lineno, value = _safeLinenoValue(t) + _error(Loc(Parser.current.filename, lineno), + "bad syntax near `%s'", value) diff --git a/ipc/ipdl/ipdl/type.py b/ipc/ipdl/ipdl/type.py new file mode 100644 index 000000000..68a10cd01 --- /dev/null +++ b/ipc/ipdl/ipdl/type.py @@ -0,0 +1,2200 @@ +# vim: set ts=4 sw=4 tw=99 et: +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +import os, sys + +from ipdl.ast import CxxInclude, Decl, Loc, QualifiedId, State, StructDecl, TransitionStmt +from ipdl.ast import TypeSpec, UnionDecl, UsingStmt, Visitor +from ipdl.ast import ASYNC, SYNC, INTR +from ipdl.ast import IN, OUT, INOUT, ANSWER, CALL, RECV, SEND +from ipdl.ast import NOT_NESTED, INSIDE_SYNC_NESTED, INSIDE_CPOW_NESTED +import ipdl.builtin as builtin + +_DELETE_MSG = '__delete__' + + +def _otherside(side): + if side == 'parent': return 'child' + elif side == 'child': return 'parent' + else: assert 0 and 'unknown side "%s"'% (side) + +def unique_pairs(s): + n = len(s) + for i, e1 in enumerate(s): + for j in xrange(i+1, n): + yield (e1, s[j]) + +def cartesian_product(s1, s2): + for e1 in s1: + for e2 in s2: + yield (e1, e2) + + +class TypeVisitor: + def __init__(self): + self.visited = set() + + def defaultVisit(self, node, *args): + raise Exception, "INTERNAL ERROR: no visitor for node type `%s'"% ( + node.__class__.__name__) + + def visitVoidType(self, v, *args): + pass + + def visitBuiltinCxxType(self, t, *args): + pass + + def visitImportedCxxType(self, t, *args): + pass + + def visitStateType(self, s, *args): + pass + + def visitMessageType(self, m, *args): + for param in m.params: + param.accept(self, *args) + for ret in m.returns: + ret.accept(self, *args) + if m.cdtype is not None: + m.cdtype.accept(self, *args) + + def visitProtocolType(self, p, *args): + # NB: don't visit manager and manages. a naive default impl + # could result in an infinite loop + pass + + def visitActorType(self, a, *args): + a.protocol.accept(self, *args) + a.state.accept(self, *args) + + def visitStructType(self, s, *args): + if s in self.visited: + return + + self.visited.add(s) + for field in s.fields: + field.accept(self, *args) + + def visitUnionType(self, u, *args): + if u in self.visited: + return + + self.visited.add(u) + for component in u.components: + component.accept(self, *args) + + def visitArrayType(self, a, *args): + a.basetype.accept(self, *args) + + def visitShmemType(self, s, *args): + pass + + def visitShmemChmodType(self, c, *args): + c.shmem.accept(self) + + def visitFDType(self, s, *args): + pass + + def visitEndpointType(self, s, *args): + pass + +class Type: + def __cmp__(self, o): + return cmp(self.fullname(), o.fullname()) + def __eq__(self, o): + return (self.__class__ == o.__class__ + and self.fullname() == o.fullname()) + def __hash__(self): + return hash(self.fullname()) + + # Is this a C++ type? + def isCxx(self): + return False + # Is this an IPDL type? + def isIPDL(self): + return False + # Is this type neither compound nor an array? + def isAtom(self): + return False + # Can this type appear in IPDL programs? + def isVisible(self): + return False + def isVoid(self): + return False + def typename(self): + return self.__class__.__name__ + + def name(self): raise Exception, 'NYI' + def fullname(self): raise Exception, 'NYI' + + def accept(self, visitor, *args): + visit = getattr(visitor, 'visit'+ self.__class__.__name__, None) + if visit is None: + return getattr(visitor, 'defaultVisit')(self, *args) + return visit(self, *args) + +class VoidType(Type): + def isCxx(self): + return True + def isIPDL(self): + return False + def isAtom(self): + return True + def isVisible(self): + return False + def isVoid(self): + return True + + def name(self): return 'void' + def fullname(self): return 'void' + +VOID = VoidType() + +##-------------------- +class CxxType(Type): + def isCxx(self): + return True + def isAtom(self): + return True + def isBuiltin(self): + return False + def isImported(self): + return False + def isGenerated(self): + return False + def isVisible(self): + return True + +class BuiltinCxxType(CxxType): + def __init__(self, qname): + assert isinstance(qname, QualifiedId) + self.loc = qname.loc + self.qname = qname + def isBuiltin(self): return True + + def name(self): + return self.qname.baseid + def fullname(self): + return str(self.qname) + +class ImportedCxxType(CxxType): + def __init__(self, qname): + assert isinstance(qname, QualifiedId) + self.loc = qname.loc + self.qname = qname + def isImported(self): return True + + def name(self): + return self.qname.baseid + def fullname(self): + return str(self.qname) + +##-------------------- +class IPDLType(Type): + def isIPDL(self): return True + def isVisible(self): return True + def isState(self): return False + def isMessage(self): return False + def isProtocol(self): return False + def isActor(self): return False + def isStruct(self): return False + def isUnion(self): return False + def isArray(self): return False + def isAtom(self): return True + def isCompound(self): return False + def isShmem(self): return False + def isChmod(self): return False + def isFD(self): return False + def isEndpoint(self): return False + + def isAsync(self): return self.sendSemantics == ASYNC + def isSync(self): return self.sendSemantics == SYNC + def isInterrupt(self): return self.sendSemantics is INTR + + def hasReply(self): return (self.isSync() or self.isInterrupt()) + + @classmethod + def convertsTo(cls, lesser, greater): + if (lesser.nestedRange[0] < greater.nestedRange[0] or + lesser.nestedRange[1] > greater.nestedRange[1]): + return False + + # Protocols that use intr semantics are not allowed to use + # message nesting. + if (greater.isInterrupt() and + lesser.nestedRange != (NOT_NESTED, NOT_NESTED)): + return False + + if lesser.isAsync(): + return True + elif lesser.isSync() and not greater.isAsync(): + return True + elif greater.isInterrupt(): + return True + + return False + + def needsMoreJuiceThan(self, o): + return not IPDLType.convertsTo(self, o) + +class StateType(IPDLType): + def __init__(self, protocol, name, start=False): + self.protocol = protocol + self.name = name + self.start = start + def isState(self): return True + def name(self): + return self.name + def fullname(self): + return self.name() + +class MessageType(IPDLType): + def __init__(self, nested, prio, sendSemantics, direction, + ctor=False, dtor=False, cdtype=None, compress=False, + verify=False): + assert not (ctor and dtor) + assert not (ctor or dtor) or type is not None + + self.nested = nested + self.prio = prio + self.nestedRange = (nested, nested) + self.sendSemantics = sendSemantics + self.direction = direction + self.params = [ ] + self.returns = [ ] + self.ctor = ctor + self.dtor = dtor + self.cdtype = cdtype + self.compress = compress + self.verify = verify + def isMessage(self): return True + + def isCtor(self): return self.ctor + def isDtor(self): return self.dtor + def constructedType(self): return self.cdtype + + def isIn(self): return self.direction is IN + def isOut(self): return self.direction is OUT + def isInout(self): return self.direction is INOUT + + def hasImplicitActorParam(self): + return self.isCtor() or self.isDtor() + +class Bridge: + def __init__(self, parentPtype, childPtype): + assert parentPtype.isToplevel() and childPtype.isToplevel() + self.parent = parentPtype + self.child = childPtype + + def __cmp__(self, o): + return cmp(self.parent, o.parent) or cmp(self.child, o.child) + def __eq__(self, o): + return self.parent == o.parent and self.child == o.child + def __hash__(self): + return hash(self.parent) + hash(self.child) + +class ProtocolType(IPDLType): + def __init__(self, qname, nestedRange, sendSemantics, stateless=False): + self.qname = qname + self.nestedRange = nestedRange + self.sendSemantics = sendSemantics + self.spawns = set() # ProtocolType + self.opens = set() # ProtocolType + self.managers = [] # ProtocolType + self.manages = [ ] + self.stateless = stateless + self.hasDelete = False + self.hasReentrantDelete = False + def isProtocol(self): return True + + def name(self): + return self.qname.baseid + def fullname(self): + return str(self.qname) + + def addManager(self, mgrtype): + assert mgrtype.isIPDL() and mgrtype.isProtocol() + self.managers.append(mgrtype) + + def addSpawn(self, ptype): + assert self.isToplevel() and ptype.isToplevel() + self.spawns.add(ptype) + + def addOpen(self, ptype): + assert self.isToplevel() and ptype.isToplevel() + self.opens.add(ptype) + + def managedBy(self, mgr): + self.managers = list(mgr) + + def toplevel(self): + if self.isToplevel(): + return self + for mgr in self.managers: + if mgr is not self: + return mgr.toplevel() + + def toplevels(self): + if self.isToplevel(): + return [self] + toplevels = list() + for mgr in self.managers: + if mgr is not self: + toplevels.extend(mgr.toplevels()) + return set(toplevels) + + def isManagerOf(self, pt): + for managed in self.manages: + if pt is managed: + return True + return False + def isManagedBy(self, pt): + return pt in self.managers + + def isManager(self): + return len(self.manages) > 0 + def isManaged(self): + return 0 < len(self.managers) + def isToplevel(self): + return not self.isManaged() + + def manager(self): + assert 1 == len(self.managers) + for mgr in self.managers: return mgr + +class ActorType(IPDLType): + def __init__(self, protocol, state=None, nullable=0): + self.protocol = protocol + self.state = state + self.nullable = nullable + def isActor(self): return True + + def name(self): + return self.protocol.name() + def fullname(self): + return self.protocol.fullname() + +class _CompoundType(IPDLType): + def __init__(self): + self.defined = False # bool + self.mutualRec = set() # set(_CompoundType | ArrayType) + def isAtom(self): + return False + def isCompound(self): + return True + def itercomponents(self): + raise Exception('"pure virtual" method') + + def mutuallyRecursiveWith(self, t, exploring=None): + '''|self| is mutually recursive with |t| iff |self| and |t| +are in a cycle in the type graph rooted at |self|. This function +looks for such a cycle and returns True if found.''' + if exploring is None: + exploring = set() + + if t.isAtom(): + return False + elif t is self or t in self.mutualRec: + return True + elif t.isArray(): + isrec = self.mutuallyRecursiveWith(t.basetype, exploring) + if isrec: self.mutualRec.add(t) + return isrec + elif t in exploring: + return False + + exploring.add(t) + for c in t.itercomponents(): + if self.mutuallyRecursiveWith(c, exploring): + self.mutualRec.add(c) + return True + exploring.remove(t) + + return False + +class StructType(_CompoundType): + def __init__(self, qname, fields): + _CompoundType.__init__(self) + self.qname = qname + self.fields = fields # [ Type ] + + def isStruct(self): return True + def itercomponents(self): + for f in self.fields: + yield f + + def name(self): return self.qname.baseid + def fullname(self): return str(self.qname) + +class UnionType(_CompoundType): + def __init__(self, qname, components): + _CompoundType.__init__(self) + self.qname = qname + self.components = components # [ Type ] + + def isUnion(self): return True + def itercomponents(self): + for c in self.components: + yield c + + def name(self): return self.qname.baseid + def fullname(self): return str(self.qname) + +class ArrayType(IPDLType): + def __init__(self, basetype): + self.basetype = basetype + def isAtom(self): return False + def isArray(self): return True + + def name(self): return self.basetype.name() +'[]' + def fullname(self): return self.basetype.fullname() +'[]' + +class ShmemType(IPDLType): + def __init__(self, qname): + self.qname = qname + def isShmem(self): return True + + def name(self): + return self.qname.baseid + def fullname(self): + return str(self.qname) + +class FDType(IPDLType): + def __init__(self, qname): + self.qname = qname + def isFD(self): return True + + def name(self): + return self.qname.baseid + def fullname(self): + return str(self.qname) + +class EndpointType(IPDLType): + def __init__(self, qname): + self.qname = qname + def isEndpoint(self): return True + + def name(self): + return self.qname.baseid + def fullname(self): + return str(self.qname) + +def iteractortypes(t, visited=None): + """Iterate over any actor(s) buried in |type|.""" + if visited is None: + visited = set() + + # XXX |yield| semantics makes it hard to use TypeVisitor + if not t.isIPDL(): + return + elif t.isActor(): + yield t + elif t.isArray(): + for actor in iteractortypes(t.basetype, visited): + yield actor + elif t.isCompound() and t not in visited: + visited.add(t) + for c in t.itercomponents(): + for actor in iteractortypes(c, visited): + yield actor + +def hasactor(type): + """Return true iff |type| is an actor or has one buried within.""" + for _ in iteractortypes(type): return True + return False + +def hasshmem(type): + """Return true iff |type| is shmem or has it buried within.""" + class found: pass + class findShmem(TypeVisitor): + def visitShmemType(self, s): raise found() + try: + type.accept(findShmem()) + except found: + return True + return False + +def hasfd(type): + """Return true iff |type| is fd or has it buried within.""" + class found: pass + class findFD(TypeVisitor): + def visitFDType(self, s): raise found() + try: + type.accept(findFD()) + except found: + return True + return False + +##-------------------- +_builtinloc = Loc('<builtin>', 0) +def makeBuiltinUsing(tname): + quals = tname.split('::') + base = quals.pop() + quals = quals[0:] + return UsingStmt(_builtinloc, + TypeSpec(_builtinloc, + QualifiedId(_builtinloc, base, quals))) + +builtinUsing = [ makeBuiltinUsing(t) for t in builtin.Types ] +builtinHeaderIncludes = [ CxxInclude(_builtinloc, f) for f in builtin.HeaderIncludes ] + +def errormsg(loc, fmt, *args): + while not isinstance(loc, Loc): + if loc is None: loc = Loc.NONE + else: loc = loc.loc + return '%s: error: %s'% (str(loc), fmt % args) + +##-------------------- +class SymbolTable: + def __init__(self, errors): + self.errors = errors + self.scopes = [ { } ] # stack({}) + self.globalScope = self.scopes[0] + self.currentScope = self.globalScope + + def enterScope(self, node): + assert (isinstance(self.scopes[0], dict) + and self.globalScope is self.scopes[0]) + assert (isinstance(self.currentScope, dict)) + + if not hasattr(node, 'symtab'): + node.symtab = { } + + self.scopes.append(node.symtab) + self.currentScope = self.scopes[-1] + + def exitScope(self, node): + symtab = self.scopes.pop() + assert self.currentScope is symtab + + self.currentScope = self.scopes[-1] + + assert (isinstance(self.scopes[0], dict) + and self.globalScope is self.scopes[0]) + assert isinstance(self.currentScope, dict) + + def lookup(self, sym): + # NB: since IPDL doesn't allow any aliased names of different types, + # it doesn't matter in which order we walk the scope chain to resolve + # |sym| + for scope in self.scopes: + decl = scope.get(sym, None) + if decl is not None: return decl + return None + + def declare(self, decl): + assert decl.progname or decl.shortname or decl.fullname + assert decl.loc + assert decl.type + + def tryadd(name): + olddecl = self.lookup(name) + if olddecl is not None: + self.errors.append(errormsg( + decl.loc, + "redeclaration of symbol `%s', first declared at %s", + name, olddecl.loc)) + return + self.currentScope[name] = decl + decl.scope = self.currentScope + + if decl.progname: tryadd(decl.progname) + if decl.shortname: tryadd(decl.shortname) + if decl.fullname: tryadd(decl.fullname) + + +class TypeCheck: + '''This pass sets the .type attribute of every AST node. For some +nodes, the type is meaningless and it is set to "VOID." This pass +also sets the .decl attribute of AST nodes for which that is relevant; +a decl says where, with what type, and under what name(s) a node was +declared. + +With this information, it finally type checks the AST.''' + + def __init__(self): + # NB: no IPDL compile will EVER print a warning. A program has + # one of two attributes: it is either well typed, or not well typed. + self.errors = [ ] # [ string ] + + def check(self, tu, errout=sys.stderr): + def runpass(tcheckpass): + tu.accept(tcheckpass) + if len(self.errors): + self.reportErrors(errout) + return False + return True + + # tag each relevant node with "decl" information, giving type, name, + # and location of declaration + if not runpass(GatherDecls(builtinUsing, self.errors)): + return False + + # now that the nodes have decls, type checking is much easier. + if not runpass(CheckTypes(self.errors)): + return False + + if not (runpass(BuildProcessGraph(self.errors)) + and runpass(CheckProcessGraph(self.errors))): + return False + + if (tu.protocol + and len(tu.protocol.startStates) + and not runpass(CheckStateMachine(self.errors))): + return False + return True + + def reportErrors(self, errout): + for error in self.errors: + print >>errout, error + + +class TcheckVisitor(Visitor): + def __init__(self, symtab, errors): + self.symtab = symtab + self.errors = errors + + def error(self, loc, fmt, *args): + self.errors.append(errormsg(loc, fmt, *args)) + + def declare(self, loc, type, shortname=None, fullname=None, progname=None): + d = Decl(loc) + d.type = type + d.progname = progname + d.shortname = shortname + d.fullname = fullname + self.symtab.declare(d) + return d + +class GatherDecls(TcheckVisitor): + def __init__(self, builtinUsing, errors): + # |self.symtab| is the symbol table for the translation unit + # currently being visited + TcheckVisitor.__init__(self, None, errors) + self.builtinUsing = builtinUsing + + def visitTranslationUnit(self, tu): + # all TranslationUnits declare symbols in global scope + if hasattr(tu, 'symtab'): + return + tu.symtab = SymbolTable(self.errors) + savedSymtab = self.symtab + self.symtab = tu.symtab + + # pretend like the translation unit "using"-ed these for the + # sake of type checking and C++ code generation + tu.builtinUsing = self.builtinUsing + + # for everyone's sanity, enforce that the filename and tu name + # match + basefilename = os.path.basename(tu.filename) + expectedfilename = '%s.ipdl'% (tu.name) + if not tu.protocol: + # header + expectedfilename += 'h' + if basefilename != expectedfilename: + self.error(tu.loc, + "expected file for translation unit `%s' to be named `%s'; instead it's named `%s'", + tu.name, expectedfilename, basefilename) + + if tu.protocol: + assert tu.name == tu.protocol.name + + p = tu.protocol + + # FIXME/cjones: it's a little weird and counterintuitive + # to put both the namespace and non-namespaced name in the + # global scope. try to figure out something better; maybe + # a type-neutral |using| that works for C++ and protocol + # types? + qname = p.qname() + if 0 == len(qname.quals): + fullname = None + else: + fullname = str(qname) + p.decl = self.declare( + loc=p.loc, + type=ProtocolType(qname, p.nestedRange, p.sendSemantics, + stateless=(0 == len(p.transitionStmts))), + shortname=p.name, + fullname=fullname) + + p.parentEndpointDecl = self.declare( + loc=p.loc, + type=EndpointType(QualifiedId(p.loc, 'Endpoint<' + fullname + 'Parent>', ['mozilla', 'ipc'])), + shortname='Endpoint<' + p.name + 'Parent>') + p.childEndpointDecl = self.declare( + loc=p.loc, + type=EndpointType(QualifiedId(p.loc, 'Endpoint<' + fullname + 'Child>', ['mozilla', 'ipc'])), + shortname='Endpoint<' + p.name + 'Child>') + + # XXX ugh, this sucks. but we need this information to compute + # what friend decls we need in generated C++ + p.decl.type._ast = p + + # make sure we have decls for all dependent protocols + for pinc in tu.includes: + pinc.accept(self) + + # declare imported (and builtin) C++ types + for using in tu.builtinUsing: + using.accept(self) + for using in tu.using: + using.accept(self) + + # first pass to "forward-declare" all structs and unions in + # order to support recursive definitions + for su in tu.structsAndUnions: + self.declareStructOrUnion(su) + + # second pass to check each definition + for su in tu.structsAndUnions: + su.accept(self) + for inc in tu.includes: + if inc.tu.filetype == 'header': + for su in inc.tu.structsAndUnions: + su.accept(self) + + if tu.protocol: + # grab symbols in the protocol itself + p.accept(self) + + + tu.type = VOID + + self.symtab = savedSymtab + + def declareStructOrUnion(self, su): + if hasattr(su, 'decl'): + self.symtab.declare(su.decl) + return + + qname = su.qname() + if 0 == len(qname.quals): + fullname = None + else: + fullname = str(qname) + + if isinstance(su, StructDecl): + sutype = StructType(qname, [ ]) + elif isinstance(su, UnionDecl): + sutype = UnionType(qname, [ ]) + else: assert 0 and 'unknown type' + + # XXX more suckage. this time for pickling structs/unions + # declared in headers. + sutype._ast = su + + su.decl = self.declare( + loc=su.loc, + type=sutype, + shortname=su.name, + fullname=fullname) + + + def visitInclude(self, inc): + if inc.tu is None: + self.error( + inc.loc, + "(type checking here will be unreliable because of an earlier error)") + return + inc.tu.accept(self) + if inc.tu.protocol: + self.symtab.declare(inc.tu.protocol.decl) + self.symtab.declare(inc.tu.protocol.parentEndpointDecl) + self.symtab.declare(inc.tu.protocol.childEndpointDecl) + else: + # This is a header. Import its "exported" globals into + # our scope. + for using in inc.tu.using: + using.accept(self) + for su in inc.tu.structsAndUnions: + self.declareStructOrUnion(su) + + def visitStructDecl(self, sd): + # If we've already processed this struct, don't do it again. + if hasattr(sd, 'symtab'): + return + + stype = sd.decl.type + + self.symtab.enterScope(sd) + + for f in sd.fields: + ftypedecl = self.symtab.lookup(str(f.typespec)) + if ftypedecl is None: + self.error(f.loc, "field `%s' of struct `%s' has unknown type `%s'", + f.name, sd.name, str(f.typespec)) + continue + + f.decl = self.declare( + loc=f.loc, + type=self._canonicalType(ftypedecl.type, f.typespec), + shortname=f.name, + fullname=None) + stype.fields.append(f.decl.type) + + self.symtab.exitScope(sd) + + def visitUnionDecl(self, ud): + utype = ud.decl.type + + # If we've already processed this union, don't do it again. + if len(utype.components): + return + + for c in ud.components: + cdecl = self.symtab.lookup(str(c)) + if cdecl is None: + self.error(c.loc, "unknown component type `%s' of union `%s'", + str(c), ud.name) + continue + utype.components.append(self._canonicalType(cdecl.type, c)) + + def visitUsingStmt(self, using): + fullname = str(using.type) + if using.type.basename() == fullname: + fullname = None + if fullname == 'mozilla::ipc::Shmem': + ipdltype = ShmemType(using.type.spec) + elif fullname == 'mozilla::ipc::FileDescriptor': + ipdltype = FDType(using.type.spec) + else: + ipdltype = ImportedCxxType(using.type.spec) + existingType = self.symtab.lookup(ipdltype.fullname()) + if existingType and existingType.fullname == ipdltype.fullname(): + using.decl = existingType + return + using.decl = self.declare( + loc=using.loc, + type=ipdltype, + shortname=using.type.basename(), + fullname=fullname) + + def visitProtocol(self, p): + # protocol scope + self.symtab.enterScope(p) + + for spawns in p.spawnsStmts: + spawns.accept(self) + + for bridges in p.bridgesStmts: + bridges.accept(self) + + for opens in p.opensStmts: + opens.accept(self) + + seenmgrs = set() + for mgr in p.managers: + if mgr.name in seenmgrs: + self.error(mgr.loc, "manager `%s' appears multiple times", + mgr.name) + continue + + seenmgrs.add(mgr.name) + mgr.of = p + mgr.accept(self) + + for managed in p.managesStmts: + managed.manager = p + managed.accept(self) + + if 0 == len(p.managers) and 0 == len(p.messageDecls): + self.error(p.loc, + "top-level protocol `%s' cannot be empty", + p.name) + + setattr(self, 'currentProtocolDecl', p.decl) + for msg in p.messageDecls: + msg.accept(self) + del self.currentProtocolDecl + + p.decl.type.hasDelete = (not not self.symtab.lookup(_DELETE_MSG)) + if not (p.decl.type.hasDelete or p.decl.type.isToplevel()): + self.error( + p.loc, + "destructor declaration `%s(...)' required for managed protocol `%s'", + _DELETE_MSG, p.name) + + p.decl.type.hasReentrantDelete = p.decl.type.hasDelete and self.symtab.lookup(_DELETE_MSG).type.isInterrupt() + + for managed in p.managesStmts: + mgdname = managed.name + ctordecl = self.symtab.lookup(mgdname +'Constructor') + + if not (ctordecl and ctordecl.type.isCtor()): + self.error( + managed.loc, + "constructor declaration required for managed protocol `%s' (managed by protocol `%s')", + mgdname, p.name) + + p.states = { } + + if len(p.transitionStmts): + p.startStates = [ ts for ts in p.transitionStmts + if ts.state.start ] + if 0 == len(p.startStates): + p.startStates = [ p.transitionStmts[0] ] + + # declare implicit "any", "dead", and "dying" states + self.declare(loc=State.ANY.loc, + type=StateType(p.decl.type, State.ANY.name, start=False), + progname=State.ANY.name) + self.declare(loc=State.DEAD.loc, + type=StateType(p.decl.type, State.DEAD.name, start=False), + progname=State.DEAD.name) + if p.decl.type.hasReentrantDelete: + self.declare(loc=State.DYING.loc, + type=StateType(p.decl.type, State.DYING.name, start=False), + progname=State.DYING.name) + + # declare each state before decorating their mention + for trans in p.transitionStmts: + p.states[trans.state] = trans + trans.state.decl = self.declare( + loc=trans.state.loc, + type=StateType(p.decl.type, trans.state, trans.state.start), + progname=trans.state.name) + + for trans in p.transitionStmts: + self.seentriggers = set() + trans.accept(self) + + if not (p.decl.type.stateless + or (p.decl.type.isToplevel() + and None is self.symtab.lookup(_DELETE_MSG))): + # add a special state |state DEAD: null goto DEAD;| + deadtrans = TransitionStmt.makeNullStmt(State.DEAD) + p.states[State.DEAD] = deadtrans + if p.decl.type.hasReentrantDelete: + dyingtrans = TransitionStmt.makeNullStmt(State.DYING) + p.states[State.DYING] = dyingtrans + + # visit the message decls once more and resolve the state names + # attached to actor params and returns + def resolvestate(loc, actortype): + assert actortype.isIPDL() and actortype.isActor() + + # already resolved this guy's state + if isinstance(actortype.state, Decl): + return + + if actortype.state is None: + # we thought this was a C++ type until type checking, + # when we realized it was an IPDL actor type. But + # that means that the actor wasn't specified to be in + # any particular state + actortype.state = State.ANY + + statename = actortype.state.name + # FIXME/cjones: this is just wrong. we need the symbol table + # of the protocol this actor refers to. low priority bug + # since nobody's using this feature yet + statedecl = self.symtab.lookup(statename) + if statedecl is None: + self.error( + loc, + "protocol `%s' does not have the state `%s'", + actortype.protocol.name(), + statename) + elif not statedecl.type.isState(): + self.error( + loc, + "tag `%s' is supposed to be of state type, but is instead of type `%s'", + statename, + statedecl.type.typename()) + else: + actortype.state = statedecl.type + + for msg in p.messageDecls: + for iparam in msg.inParams: + loc = iparam.loc + for actortype in iteractortypes(iparam.type): + resolvestate(loc, actortype) + for oparam in msg.outParams: + loc = oparam.loc + for actortype in iteractortypes(oparam.type): + resolvestate(loc, actortype) + + # FIXME/cjones declare all the little C++ thingies that will + # be generated. they're not relevant to IPDL itself, but + # those ("invisible") symbols can clash with others in the + # IPDL spec, and we'd like to catch those before C++ compilers + # are allowed to obfuscate the error + + self.symtab.exitScope(p) + + + def visitSpawnsStmt(self, spawns): + pname = spawns.proto + spawns.proto = self.symtab.lookup(pname) + if spawns.proto is None: + self.error(spawns.loc, + "spawned protocol `%s' has not been declared", + pname) + + def visitBridgesStmt(self, bridges): + def lookup(p): + decl = self.symtab.lookup(p) + if decl is None: + self.error(bridges.loc, + "bridged protocol `%s' has not been declared", p) + return decl + bridges.parentSide = lookup(bridges.parentSide) + bridges.childSide = lookup(bridges.childSide) + + def visitOpensStmt(self, opens): + pname = opens.proto + opens.proto = self.symtab.lookup(pname) + if opens.proto is None: + self.error(opens.loc, + "opened protocol `%s' has not been declared", + pname) + + + def visitManager(self, mgr): + mgrdecl = self.symtab.lookup(mgr.name) + pdecl = mgr.of.decl + assert pdecl + + pname, mgrname = pdecl.shortname, mgr.name + loc = mgr.loc + + if mgrdecl is None: + self.error( + loc, + "protocol `%s' referenced as |manager| of `%s' has not been declared", + mgrname, pname) + elif not isinstance(mgrdecl.type, ProtocolType): + self.error( + loc, + "entity `%s' referenced as |manager| of `%s' is not of `protocol' type; instead it is of type `%s'", + mgrname, pname, mgrdecl.type.typename()) + else: + mgr.decl = mgrdecl + pdecl.type.addManager(mgrdecl.type) + + + def visitManagesStmt(self, mgs): + mgsdecl = self.symtab.lookup(mgs.name) + pdecl = mgs.manager.decl + assert pdecl + + pname, mgsname = pdecl.shortname, mgs.name + loc = mgs.loc + + if mgsdecl is None: + self.error(loc, + "protocol `%s', managed by `%s', has not been declared", + mgsname, pname) + elif not isinstance(mgsdecl.type, ProtocolType): + self.error( + loc, + "%s declares itself managing a non-`protocol' entity `%s' of type `%s'", + pname, mgsname, mgsdecl.type.typename()) + else: + mgs.decl = mgsdecl + pdecl.type.manages.append(mgsdecl.type) + + + def visitMessageDecl(self, md): + msgname = md.name + loc = md.loc + + isctor = False + isdtor = False + cdtype = None + + decl = self.symtab.lookup(msgname) + if decl is not None and decl.type.isProtocol(): + # probably a ctor. we'll check validity later. + msgname += 'Constructor' + isctor = True + cdtype = decl.type + elif decl is not None: + self.error(loc, "message name `%s' already declared as `%s'", + msgname, decl.type.typename()) + # if we error here, no big deal; move on to find more + + if _DELETE_MSG == msgname: + isdtor = True + cdtype = self.currentProtocolDecl.type + + + # enter message scope + self.symtab.enterScope(md) + + msgtype = MessageType(md.nested, md.prio, md.sendSemantics, md.direction, + ctor=isctor, dtor=isdtor, cdtype=cdtype, + compress=md.compress, verify=md.verify) + + # replace inparam Param nodes with proper Decls + def paramToDecl(param): + ptname = param.typespec.basename() + ploc = param.typespec.loc + + ptdecl = self.symtab.lookup(ptname) + if ptdecl is None: + self.error( + ploc, + "argument typename `%s' of message `%s' has not been declared", + ptname, msgname) + ptype = VOID + else: + ptype = self._canonicalType(ptdecl.type, param.typespec, + chmodallowed=1) + return self.declare(loc=ploc, + type=ptype, + progname=param.name) + + for i, inparam in enumerate(md.inParams): + pdecl = paramToDecl(inparam) + msgtype.params.append(pdecl.type) + md.inParams[i] = pdecl + for i, outparam in enumerate(md.outParams): + pdecl = paramToDecl(outparam) + msgtype.returns.append(pdecl.type) + md.outParams[i] = pdecl + + self.symtab.exitScope(md) + + md.decl = self.declare( + loc=loc, + type=msgtype, + progname=msgname) + md.protocolDecl = self.currentProtocolDecl + md.decl._md = md + + + def visitTransitionStmt(self, ts): + self.seentriggers = set() + TcheckVisitor.visitTransitionStmt(self, ts) + + def visitTransition(self, t): + loc = t.loc + + # check the trigger message + mname = t.msg + if t in self.seentriggers: + self.error(loc, "trigger `%s' appears multiple times", t.msg) + self.seentriggers.add(t) + + mdecl = self.symtab.lookup(mname) + if mdecl is not None and mdecl.type.isIPDL() and mdecl.type.isProtocol(): + mdecl = self.symtab.lookup(mname +'Constructor') + + if mdecl is None: + self.error(loc, "message `%s' has not been declared", mname) + elif not mdecl.type.isMessage(): + self.error( + loc, + "`%s' should have message type, but instead has type `%s'", + mname, mdecl.type.typename()) + else: + t.msg = mdecl + + # check the to-states + seenstates = set() + for toState in t.toStates: + sname = toState.name + sdecl = self.symtab.lookup(sname) + + if sname in seenstates: + self.error(loc, "to-state `%s' appears multiple times", sname) + seenstates.add(sname) + + if sdecl is None: + self.error(loc, "state `%s' has not been declared", sname) + elif not sdecl.type.isState(): + self.error( + loc, "`%s' should have state type, but instead has type `%s'", + sname, sdecl.type.typename()) + else: + toState.decl = sdecl + toState.start = sdecl.type.start + + t.toStates = set(t.toStates) + + + def _canonicalType(self, itype, typespec, chmodallowed=0): + loc = typespec.loc + + if itype.isIPDL(): + if itype.isProtocol(): + itype = ActorType(itype, + state=typespec.state, + nullable=typespec.nullable) + # FIXME/cjones: ShmemChmod is disabled until bug 524193 + if 0 and chmodallowed and itype.isShmem(): + itype = ShmemChmodType( + itype, + myChmod=typespec.myChmod, + otherChmod=typespec.otherChmod) + + if ((typespec.myChmod or typespec.otherChmod) + and not (itype.isIPDL() and (itype.isShmem() or itype.isChmod()))): + self.error( + loc, + "fine-grained access controls make no sense for type `%s'", + itype.name()) + + if not chmodallowed and (typespec.myChmod or typespec.otherChmod): + self.error(loc, "fine-grained access controls not allowed here") + + if typespec.nullable and not (itype.isIPDL() and itype.isActor()): + self.error( + loc, + "`nullable' qualifier for type `%s' makes no sense", + itype.name()) + + if typespec.array: + itype = ArrayType(itype) + + return itype + + +##----------------------------------------------------------------------------- + +def checkcycles(p, stack=None): + cycles = [] + + if stack is None: + stack = [] + + for cp in p.manages: + # special case for self-managed protocols + if cp is p: + continue + + if cp in stack: + return [stack + [p, cp]] + cycles += checkcycles(cp, stack + [p]) + + return cycles + +def formatcycles(cycles): + r = [] + for cycle in cycles: + s = " -> ".join([ptype.name() for ptype in cycle]) + r.append("`%s'" % s) + return ", ".join(r) + + +def fullyDefined(t, exploring=None): + '''The rules for "full definition" of a type are + defined(atom) := true + defined(array basetype) := defined(basetype) + defined(struct f1 f2...) := defined(f1) and defined(f2) and ... + defined(union c1 c2 ...) := defined(c1) or defined(c2) or ... +''' + if exploring is None: + exploring = set() + + if t.isAtom(): + return True + elif t.isArray(): + return fullyDefined(t.basetype, exploring) + elif t.defined: + return True + assert t.isCompound() + + if t in exploring: + return False + + exploring.add(t) + for c in t.itercomponents(): + cdefined = fullyDefined(c, exploring) + if t.isStruct() and not cdefined: + t.defined = False + break + elif t.isUnion() and cdefined: + t.defined = True + break + else: + if t.isStruct(): t.defined = True + elif t.isUnion(): t.defined = False + exploring.remove(t) + + return t.defined + + +class CheckTypes(TcheckVisitor): + def __init__(self, errors): + # don't need the symbol table, we just want the error reporting + TcheckVisitor.__init__(self, None, errors) + self.visited = set() + self.ptype = None + + def visitInclude(self, inc): + if inc.tu.filename in self.visited: + return + self.visited.add(inc.tu.filename) + if inc.tu.protocol: + inc.tu.protocol.accept(self) + + + def visitStructDecl(self, sd): + if not fullyDefined(sd.decl.type): + self.error(sd.decl.loc, + "struct `%s' is only partially defined", sd.name) + + def visitUnionDecl(self, ud): + if not fullyDefined(ud.decl.type): + self.error(ud.decl.loc, + "union `%s' is only partially defined", ud.name) + + + def visitProtocol(self, p): + self.ptype = p.decl.type + + # check that we require no more "power" than our manager protocols + ptype, pname = p.decl.type, p.decl.shortname + + if len(p.spawnsStmts) and not ptype.isToplevel(): + self.error(p.decl.loc, + "protocol `%s' is not top-level and so cannot declare |spawns|", + pname) + + if len(p.bridgesStmts) and not ptype.isToplevel(): + self.error(p.decl.loc, + "protocol `%s' is not top-level and so cannot declare |bridges|", + pname) + + if len(p.opensStmts) and not ptype.isToplevel(): + self.error(p.decl.loc, + "protocol `%s' is not top-level and so cannot declare |opens|", + pname) + + for mgrtype in ptype.managers: + if mgrtype is not None and ptype.needsMoreJuiceThan(mgrtype): + self.error( + p.decl.loc, + "protocol `%s' requires more powerful send semantics than its manager `%s' provides", + pname, mgrtype.name()) + + # XXX currently we don't require a delete() message of top-level + # actors. need to let experience guide this decision + if not ptype.isToplevel(): + for md in p.messageDecls: + if _DELETE_MSG == md.name: break + else: + self.error( + p.decl.loc, + "managed protocol `%s' requires a `delete()' message to be declared", + p.name) + else: + cycles = checkcycles(p.decl.type) + if cycles: + self.error( + p.decl.loc, + "cycle(s) detected in manager/manages heirarchy: %s", + formatcycles(cycles)) + + if 1 == len(ptype.managers) and ptype is ptype.manager(): + self.error( + p.decl.loc, + "top-level protocol `%s' cannot manage itself", + p.name) + + return Visitor.visitProtocol(self, p) + + + def visitSpawnsStmt(self, spawns): + if not self.ptype.isToplevel(): + self.error(spawns.loc, + "only top-level protocols can have |spawns| statements; `%s' cannot", + self.ptype.name()) + return + + spawnedType = spawns.proto.type + if not (spawnedType.isIPDL() and spawnedType.isProtocol() + and spawnedType.isToplevel()): + self.error(spawns.loc, + "cannot spawn non-top-level-protocol `%s'", + spawnedType.name()) + else: + self.ptype.addSpawn(spawnedType) + + + def visitBridgesStmt(self, bridges): + if not self.ptype.isToplevel(): + self.error(bridges.loc, + "only top-level protocols can have |bridges| statements; `%s' cannot", + self.ptype.name()) + return + + parentType = bridges.parentSide.type + childType = bridges.childSide.type + if not (parentType.isIPDL() and parentType.isProtocol() + and childType.isIPDL() and childType.isProtocol() + and parentType.isToplevel() and childType.isToplevel()): + self.error(bridges.loc, + "cannot bridge non-top-level-protocol(s) `%s' and `%s'", + parentType.name(), childType.name()) + + + def visitOpensStmt(self, opens): + if not self.ptype.isToplevel(): + self.error(opens.loc, + "only top-level protocols can have |opens| statements; `%s' cannot", + self.ptype.name()) + return + + openedType = opens.proto.type + if not (openedType.isIPDL() and openedType.isProtocol() + and openedType.isToplevel()): + self.error(opens.loc, + "cannot open non-top-level-protocol `%s'", + openedType.name()) + else: + self.ptype.addOpen(openedType) + + + def visitManagesStmt(self, mgs): + pdecl = mgs.manager.decl + ptype, pname = pdecl.type, pdecl.shortname + + mgsdecl = mgs.decl + mgstype, mgsname = mgsdecl.type, mgsdecl.shortname + + loc = mgs.loc + + # we added this information; sanity check it + assert ptype.isManagerOf(mgstype) + + # check that the "managed" protocol agrees + if not mgstype.isManagedBy(ptype): + self.error( + loc, + "|manages| declaration in protocol `%s' does not match any |manager| declaration in protocol `%s'", + pname, mgsname) + + + def visitManager(self, mgr): + # FIXME/bug 541126: check that the protocol graph is acyclic + + pdecl = mgr.of.decl + ptype, pname = pdecl.type, pdecl.shortname + + mgrdecl = mgr.decl + mgrtype, mgrname = mgrdecl.type, mgrdecl.shortname + + # we added this information; sanity check it + assert ptype.isManagedBy(mgrtype) + + loc = mgr.loc + + # check that the "manager" protocol agrees + if not mgrtype.isManagerOf(ptype): + self.error( + loc, + "|manager| declaration in protocol `%s' does not match any |manages| declaration in protocol `%s'", + pname, mgrname) + + + def visitMessageDecl(self, md): + mtype, mname = md.decl.type, md.decl.progname + ptype, pname = md.protocolDecl.type, md.protocolDecl.shortname + + loc = md.decl.loc + + if mtype.nested == INSIDE_SYNC_NESTED and not mtype.isSync(): + self.error( + loc, + "inside_sync nested messages must be sync (here, message `%s' in protocol `%s')", + mname, pname) + + if mtype.nested == INSIDE_CPOW_NESTED and (mtype.isOut() or mtype.isInout()): + self.error( + loc, + "inside_cpow nested parent-to-child messages are verboten (here, message `%s' in protocol `%s')", + mname, pname) + + # We allow inside_sync messages that are themselves sync to be sent from the + # parent. Normal and inside_cpow nested messages that are sync can only come from + # the child. + if mtype.isSync() and mtype.nested == NOT_NESTED and (mtype.isOut() or mtype.isInout()): + self.error( + loc, + "sync parent-to-child messages are verboten (here, message `%s' in protocol `%s')", + mname, pname) + + if mtype.needsMoreJuiceThan(ptype): + self.error( + loc, + "message `%s' requires more powerful send semantics than its protocol `%s' provides", + mname, pname) + + if mtype.isAsync() and len(mtype.returns): + # XXX/cjones could modify grammar to disallow this ... + self.error(loc, + "asynchronous message `%s' declares return values", + mname) + + if (mtype.compress and + (not mtype.isAsync() or mtype.isCtor() or mtype.isDtor())): + self.error( + loc, + "message `%s' in protocol `%s' requests compression but is not async or is special (ctor or dtor)", + mname[:-len('constructor')], pname) + + if mtype.isCtor() and not ptype.isManagerOf(mtype.constructedType()): + self.error( + loc, + "ctor for protocol `%s', which is not managed by protocol `%s'", + mname[:-len('constructor')], pname) + + + def visitTransition(self, t): + _YNC = [ ASYNC, SYNC ] + + loc = t.loc + impliedDirection, impliedSems = { + SEND: [ OUT, _YNC ], RECV: [ IN, _YNC ], + CALL: [ OUT, INTR ], ANSWER: [ IN, INTR ], + } [t.trigger] + + if (OUT is impliedDirection and t.msg.type.isIn() + or IN is impliedDirection and t.msg.type.isOut() + or _YNC is impliedSems and t.msg.type.isInterrupt() + or INTR is impliedSems and (not t.msg.type.isInterrupt())): + mtype = t.msg.type + + self.error( + loc, "%s %s message `%s' is not `%s'd", + mtype.sendSemantics.pretty, mtype.direction.pretty, + t.msg.progname, + t.trigger.pretty) + +##----------------------------------------------------------------------------- + +class Process: + def __init__(self): + self.actors = set() # set(Actor) + self.edges = { } # Actor -> [ SpawnsEdge ] + self.spawn = set() # set(Actor) + + def edge(self, spawner, spawn): + if spawner not in self.edges: self.edges[spawner] = [ ] + self.edges[spawner].append(SpawnsEdge(spawner, spawn)) + self.spawn.add(spawn) + + def iteredges(self): + for edgelist in self.edges.itervalues(): + for edge in edgelist: + yield edge + + def merge(self, o): + 'Merge the Process |o| into this Process' + if self == o: + return + for actor in o.actors: + ProcessGraph.actorToProcess[actor] = self + self.actors.update(o.actors) + self.edges.update(o.edges) + self.spawn.update(o.spawn) + ProcessGraph.processes.remove(o) + + def spawns(self, actor): + return actor in self.spawn + + def __cmp__(self, o): return cmp(self.actors, o.actors) + def __eq__(self, o): return self.actors == o.actors + def __hash__(self): return hash(id(self)) + def __repr__(self): + return reduce(lambda a, x: str(a) + str(x) +'|', self.actors, '|') + def __str__(self): return repr(self) + +class Actor: + def __init__(self, ptype, side): + self.ptype = ptype + self.side = side + + def asType(self): + return ActorType(self.ptype) + def other(self): + return Actor(self.ptype, _otherside(self.side)) + + def __cmp__(self, o): + return cmp(self.ptype, o.ptype) or cmp(self.side, o.side) + def __eq__(self, o): + return self.ptype == o.ptype and self.side == o.side + def __hash__(self): return hash(repr(self)) + def __repr__(self): return '%s%s'% (self.ptype.name(), self.side.title()) + def __str__(self): return repr(self) + +class SpawnsEdge: + def __init__(self, spawner, spawn): + self.spawner = spawner # Actor + self.spawn = spawn # Actor + def __repr__(self): + return '(%r)--spawns-->(%r)'% (self.spawner, self.spawn) + def __str__(self): return repr(self) + +class BridgeEdge: + def __init__(self, bridgeProto, parent, child): + self.bridgeProto = bridgeProto # ProtocolType + self.parent = parent # Actor + self.child = child # Actor + def __repr__(self): + return '(%r)--%s bridge-->(%r)'% ( + self.parent, self.bridgeProto.name(), self.child) + def __str__(self): return repr(self) + +class OpensEdge: + def __init__(self, opener, openedProto): + self.opener = opener # Actor + self.openedProto = openedProto # ProtocolType + def __repr__(self): + return '(%r)--opens-->(%s)'% (self.opener, self.openedProto.name()) + def __str__(self): return repr(self) + +# "singleton" class with state that persists across type checking of +# all protocols +class ProcessGraph: + processes = set() # set(Process) + bridges = { } # ProtocolType -> [ BridgeEdge ] + opens = { } # ProtocolType -> [ OpensEdge ] + actorToProcess = { } # Actor -> Process + visitedSpawns = set() # set(ActorType) + visitedBridges = set() # set(ActorType) + + @classmethod + def findProcess(cls, actor): + return cls.actorToProcess.get(actor, None) + + @classmethod + def getProcess(cls, actor): + if actor not in cls.actorToProcess: + p = Process() + p.actors.add(actor) + cls.processes.add(p) + cls.actorToProcess[actor] = p + return cls.actorToProcess[actor] + + @classmethod + def bridgesOf(cls, bridgeP): + return cls.bridges.get(bridgeP, []) + + @classmethod + def bridgeEndpointsOf(cls, ptype, side): + actor = Actor(ptype, side) + endpoints = [] + for b in cls.iterbridges(): + if b.parent == actor: + endpoints.append(Actor(b.bridgeProto, 'parent')) + if b.child == actor: + endpoints.append(Actor(b.bridgeProto, 'child')) + return endpoints + + @classmethod + def iterbridges(cls): + for edges in cls.bridges.itervalues(): + for bridge in edges: + yield bridge + + @classmethod + def opensOf(cls, openedP): + return cls.opens.get(openedP, []) + + @classmethod + def opensEndpointsOf(cls, ptype, side): + actor = Actor(ptype, side) + endpoints = [] + for o in cls.iteropens(): + if actor == o.opener: + endpoints.append(Actor(o.openedProto, o.opener.side)) + elif actor == o.opener.other(): + endpoints.append(Actor(o.openedProto, o.opener.other().side)) + return endpoints + + @classmethod + def iteropens(cls): + for edges in cls.opens.itervalues(): + for opens in edges: + yield opens + + @classmethod + def spawn(cls, spawner, remoteSpawn): + localSpawn = remoteSpawn.other() + spawnerProcess = ProcessGraph.getProcess(spawner) + spawnerProcess.merge(ProcessGraph.getProcess(localSpawn)) + spawnerProcess.edge(spawner, remoteSpawn) + + @classmethod + def bridge(cls, parent, child, bridgeP): + bridgeParent = Actor(bridgeP, 'parent') + parentProcess = ProcessGraph.getProcess(parent) + parentProcess.merge(ProcessGraph.getProcess(bridgeParent)) + bridgeChild = Actor(bridgeP, 'child') + childProcess = ProcessGraph.getProcess(child) + childProcess.merge(ProcessGraph.getProcess(bridgeChild)) + if bridgeP not in cls.bridges: + cls.bridges[bridgeP] = [ ] + cls.bridges[bridgeP].append(BridgeEdge(bridgeP, parent, child)) + + @classmethod + def open(cls, opener, opened, openedP): + remoteOpener, remoteOpened, = opener.other(), opened.other() + openerProcess = ProcessGraph.getProcess(opener) + openerProcess.merge(ProcessGraph.getProcess(opened)) + remoteOpenerProcess = ProcessGraph.getProcess(remoteOpener) + remoteOpenerProcess.merge(ProcessGraph.getProcess(remoteOpened)) + if openedP not in cls.opens: + cls.opens[openedP] = [ ] + cls.opens[openedP].append(OpensEdge(opener, openedP)) + + +class BuildProcessGraph(TcheckVisitor): + class findSpawns(TcheckVisitor): + def __init__(self, errors): + TcheckVisitor.__init__(self, None, errors) + + def visitTranslationUnit(self, tu): + TcheckVisitor.visitTranslationUnit(self, tu) + + def visitInclude(self, inc): + if inc.tu.protocol: + inc.tu.protocol.accept(self) + + def visitProtocol(self, p): + ptype = p.decl.type + # non-top-level protocols don't add any information + if not ptype.isToplevel() or ptype in ProcessGraph.visitedSpawns: + return + + ProcessGraph.visitedSpawns.add(ptype) + self.visiting = ptype + ProcessGraph.getProcess(Actor(ptype, 'parent')) + ProcessGraph.getProcess(Actor(ptype, 'child')) + return TcheckVisitor.visitProtocol(self, p) + + def visitSpawnsStmt(self, spawns): + # The picture here is: + # [ spawner | localSpawn | ??? ] (process 1) + # | + # | + # [ remoteSpawn | ???] (process 2) + # + # A spawns stmt tells us that |spawner| and |localSpawn| + # are in the same process. + spawner = Actor(self.visiting, spawns.side) + remoteSpawn = Actor(spawns.proto.type, spawns.spawnedAs) + ProcessGraph.spawn(spawner, remoteSpawn) + + def __init__(self, errors): + TcheckVisitor.__init__(self, None, errors) + self.visiting = None # ActorType + self.visited = set() # set(ActorType) + + def visitTranslationUnit(self, tu): + tu.accept(self.findSpawns(self.errors)) + TcheckVisitor.visitTranslationUnit(self, tu) + + def visitInclude(self, inc): + if inc.tu.protocol: + inc.tu.protocol.accept(self) + + def visitProtocol(self, p): + ptype = p.decl.type + # non-top-level protocols don't add any information + if not ptype.isToplevel() or ptype in ProcessGraph.visitedBridges: + return + + ProcessGraph.visitedBridges.add(ptype) + self.visiting = ptype + return TcheckVisitor.visitProtocol(self, p) + + def visitBridgesStmt(self, bridges): + bridgeProto = self.visiting + parentSideProto = bridges.parentSide.type + childSideProto = bridges.childSide.type + + # the picture here is: + # (process 1| + # [ parentSide(Parent|Child) | childSide(Parent|Child) | ... ] + # | | + # | (process 2| | + # [ parentSide(Child|Parent) | bridgeParent ] | + # | | + # | | (process 3| + # [ bridgeChild | childSide(Child|Parent) ] + # + # First we have to figure out which parentSide/childSide + # actors live in the same process. The possibilities are { + # parent, child } x { parent, child }. (Multiple matches + # aren't allowed yet.) Then we make ProcessGraph aware of the + # new bridge. + parentSideActor, childSideActor = None, None + pc = ( 'parent', 'child' ) + for parentSide, childSide in cartesian_product(pc, pc): + pactor = Actor(parentSideProto, parentSide) + pproc = ProcessGraph.findProcess(pactor) + cactor = Actor(childSideProto, childSide) + cproc = ProcessGraph.findProcess(cactor) + assert pproc and cproc + + if pproc == cproc: + if parentSideActor is not None: + if parentSideProto != childSideProto: + self.error(bridges.loc, + "ambiguous bridge `%s' between `%s' and `%s'", + bridgeProto.name(), + parentSideProto.name(), + childSideProto.name()) + else: + parentSideActor, childSideActor = pactor.other(), cactor.other() + + if parentSideActor is None: + self.error(bridges.loc, + "`%s' and `%s' cannot be bridged by `%s' ", + parentSideProto.name(), childSideProto.name(), + bridgeProto.name()) + + ProcessGraph.bridge(parentSideActor, childSideActor, bridgeProto) + + def visitOpensStmt(self, opens): + openedP = opens.proto.type + opener = Actor(self.visiting, opens.side) + opened = Actor(openedP, opens.side) + + # The picture here is: + # [ opener | opened ] (process 1) + # | | + # | | + # [ remoteOpener | remoteOpened ] (process 2) + # + # An opens stmt tells us that the pairs |opener|/|opened| + # and |remoteOpener|/|remoteOpened| are each in the same + # process. + ProcessGraph.open(opener, opened, openedP) + + +class CheckProcessGraph(TcheckVisitor): + def __init__(self, errors): + TcheckVisitor.__init__(self, None, errors) + + # TODO: verify spawns-per-process assumption and check that graph + # is a dag + def visitTranslationUnit(self, tu): + if 0: + print 'Processes' + for process in ProcessGraph.processes: + print ' ', process + for edge in process.iteredges(): + print ' ', edge + print 'Bridges' + for bridgeList in ProcessGraph.bridges.itervalues(): + for bridge in bridgeList: + print ' ', bridge + print 'Opens' + for opensList in ProcessGraph.opens.itervalues(): + for opens in opensList: + print ' ', opens + +##----------------------------------------------------------------------------- + +class CheckStateMachine(TcheckVisitor): + def __init__(self, errors): + # don't need the symbol table, we just want the error reporting + TcheckVisitor.__init__(self, None, errors) + self.p = None + + def visitProtocol(self, p): + self.p = p + self.checkReachability(p) + for ts in p.transitionStmts: + ts.accept(self) + + def visitTransitionStmt(self, ts): + # We want to disallow "race conditions" in protocols. These + # can occur when a protocol state machine has a state that + # allows triggers of opposite direction. That declaration + # allows the parent to send the child a message at the + # exact instance the child sends the parent a message. One of + # those messages would (probably) violate the state machine + # and cause the child to be terminated. It's obviously very + # nice if we can forbid this at the level of IPDL state + # machines, rather than resorting to static or dynamic + # checking of C++ implementation code. + # + # An easy way to avoid this problem in IPDL is to only allow + # "unidirectional" protocol states; that is, from each state, + # only send or only recv triggers are allowed. This approach + # is taken by the Singularity project's "contract-based + # message channels." However, this can be something of a + # notational burden for stateful protocols. + # + # If two messages race, the effect is that the parent's and + # child's states get temporarily out of sync. Informally, + # IPDL allows this *only if* the state machines get out of + # sync for only *one* step (state machine transition), then + # sync back up. This is a design decision: the states could + # be allowd to get out of sync for any constant k number of + # steps. (If k is unbounded, there's no point in presenting + # the abstraction of parent and child actor states being + # "entangled".) The working hypothesis is that the more steps + # the states are allowed to be out of sync, the harder it is + # to reason about the protocol. + # + # Slightly less informally, two messages are allowed to race + # only if processing them in either order leaves the protocol + # in the same state. That is, messages A and B are allowed to + # race only if processing A then B leaves the protocol in + # state S, *and* processing B then A also leaves the protocol + # in state S. Technically, if this holds, then messages A and + # B could be called "commutative" wrt to actor state. + # + # "Formally", state machine definitions must adhere to two + # rules. + # + # *Rule 1*: from a state S, all sync triggers must be of the same + # "direction," i.e. only |send| or only |recv| + # + # (Pairs of sync messages can't commute, because otherwise + # deadlock can occur from simultaneously in-flight sync + # requests.) + # + # *Rule 2*: the "Diamond Rule". + # from a state S, + # for any pair of triggers t1 and t2, + # where t1 and t2 have opposite direction, + # and t1 transitions to state T1 and t2 to T2, + # then the following must be true: + # (T2 allows the trigger t1, transitioning to state U) + # and + # (T1 allows the trigger t2, transitioning to state U) + # and + # ( + # ( + # (all of T1's triggers have the same direction as t2) + # and + # (all of T2's triggers have the same direction as t1) + # ) + # or + # (T1, T2, and U are the same "terminal state") + # ) + # + # A "terminal state" S is one from which all triggers + # transition back to S itself. + # + # The presence of triggers with multiple out states complicates + # this check slightly, but doesn't fundamentally change it. + # + # from a state S, + # for any pair of triggers t1 and t2, + # where t1 and t2 have opposite direction, + # for each pair of states (T1, T2) \in t1_out x t2_out, + # where t1_out is the set of outstates from t1 + # t2_out is the set of outstates from t2 + # t1_out x t2_out is their Cartesian product + # and t1 transitions to state T1 and t2 to T2, + # then the following must be true: + # (T2 allows the trigger t1, with out-state set { U }) + # and + # (T1 allows the trigger t2, with out-state set { U }) + # and + # ( + # ( + # (all of T1's triggers have the same direction as t2) + # and + # (all of T2's triggers have the same direction as t1) + # ) + # or + # (T1, T2, and U are the same "terminal state") + # ) + + # check Rule 1 + syncdirection = None + syncok = True + for trans in ts.transitions: + if not trans.msg.type.isSync(): continue + if syncdirection is None: + syncdirection = trans.trigger.direction() + elif syncdirection is not trans.trigger.direction(): + self.error( + trans.loc, + "sync trigger at state `%s' in protocol `%s' has different direction from earlier sync trigger at same state", + ts.state.name, self.p.name) + syncok = False + # don't check the Diamond Rule if Rule 1 doesn't hold + if not syncok: + return + + # helper functions + def triggerTargets(S, t): + '''Return the set of states transitioned to from state |S| +upon trigger |t|, or { } if |t| is not a trigger in |S|.''' + for trans in self.p.states[S].transitions: + if t.trigger is trans.trigger and t.msg is trans.msg: + return trans.toStates + return set() + + def allTriggersSameDirectionAs(S, t): + '''Return true iff all the triggers from state |S| have the same +direction as trigger |t|''' + direction = t.direction() + for trans in self.p.states[S].transitions: + if direction != trans.trigger.direction(): + return False + return True + + def terminalState(S): + '''Return true iff |S| is a "terminal state".''' + for trans in self.p.states[S].transitions: + for S_ in trans.toStates: + if S_ != S: return False + return True + + def sameTerminalState(S1, S2, S3): + '''Return true iff states |S1|, |S2|, and |S3| are all the same +"terminal state".''' + if isinstance(S3, set): + assert len(S3) == 1 + for S3_ in S3: pass + S3 = S3_ + + return (S1 == S2 == S3) and terminalState(S1) + + S = ts.state.name + + # check the Diamond Rule + for (t1, t2) in unique_pairs(ts.transitions): + # if the triggers have the same direction, they can't race, + # since only one endpoint can initiate either (and delivery + # is in-order) + if t1.trigger.direction() == t2.trigger.direction(): + continue + + loc = t1.loc + t1_out = t1.toStates + t2_out = t2.toStates + + for (T1, T2) in cartesian_product(t1_out, t2_out): + # U1 <- { u | T1 --t2--> u } + U1 = triggerTargets(T1, t2) + # U2 <- { u | T2 --t1--> u } + U2 = triggerTargets(T2, t1) + + # don't report more than one Diamond Rule violation + # per state. there may be O(n^4) total, way too many + # for a human to parse + # + # XXX/cjones: could set a limit on #printed and stop + # after that limit ... + raceError = False + errT1 = None + errT2 = None + + if 0 == len(U1) or 0 == len(U2): + print "******* case 1" + raceError = True + elif 1 < len(U1) or 1 < len(U2): + raceError = True + # there are potentially many unpaired states; just + # pick two + print "******* case 2" + for u1, u2 in cartesian_product(U1, U2): + if u1 != u2: + errT1, errT2 = u1, u2 + break + elif U1 != U2: + print "******* case 3" + raceError = True + for errT1 in U1: pass + for errT2 in U2: pass + + if raceError: + self.reportRaceError(loc, S, + [ T1, t1, errT1 ], + [ T2, t2, errT2 ]) + return + + if not ((allTriggersSameDirectionAs(T1, t2.trigger) + and allTriggersSameDirectionAs(T2, t1.trigger)) + or sameTerminalState(T1, T2, U1)): + self.reportRunawayError(loc, S, [ T1, t1, None ], [ T2, t2, None ]) + return + + def checkReachability(self, p): + def explore(ts, visited): + if ts.state in visited: + return + visited.add(ts.state) + for outedge in ts.transitions: + for toState in outedge.toStates: + explore(p.states[toState], visited) + + checkfordelete = (State.DEAD in p.states) + + allvisited = set() # set(State) + for root in p.startStates: + visited = set() + + explore(root, visited) + allvisited.update(visited) + + if checkfordelete and State.DEAD not in visited: + self.error( + root.loc, + "when starting from state `%s', actors of protocol `%s' cannot be deleted", root.state.name, p.name) + + for ts in p.states.itervalues(): + if ts.state is not State.DEAD and ts.state not in allvisited: + self.error(ts.loc, + "unreachable state `%s' in protocol `%s'", + ts.state.name, p.name) + + + def _normalizeTransitionSequences(self, t1Seq, t2Seq): + T1, M1, U1 = t1Seq + T2, M2, U2 = t2Seq + assert M1 is not None and M2 is not None + + # make sure that T1/M1/U1 is the parent side of the race + if M1.trigger is RECV or M1.trigger is ANSWER: + T1, M1, U1, T2, M2, U2 = T2, M2, U2, T1, M1, U1 + + def stateName(S): + if S: return S.name + return '[error]' + + T1 = stateName(T1) + T2 = stateName(T2) + U1 = stateName(U1) + U2 = stateName(U2) + + return T1, M1.msg.progname, U1, T2, M2.msg.progname, U2 + + + def reportRaceError(self, loc, S, t1Seq, t2Seq): + T1, M1, U1, T2, M2, U2 = self._normalizeTransitionSequences(t1Seq, t2Seq) + self.error( + loc, +"""in protocol `%(P)s', the sequence of events + parent: +--`send %(M1)s'-->( state `%(T1)s' )--`recv %(M2)s'-->( state %(U1)s ) + / + ( state `%(S)s' ) + \\ + child: +--`send %(M2)s'-->( state `%(T2)s' )--`recv %(M1)s'-->( state %(U2)s ) +results in error(s) or leaves parent/child state out of sync for more than one step and is thus a race hazard; i.e., triggers `%(M1)s' and `%(M2)s' fail to commute in state `%(S)s'"""% { + 'P': self.p.name, 'S': S, 'M1': M1, 'M2': M2, + 'T1': T1, 'T2': T2, 'U1': U1, 'U2': U2 + }) + + + def reportRunawayError(self, loc, S, t1Seq, t2Seq): + T1, M1, _, T2, M2, __ = self._normalizeTransitionSequences(t1Seq, t2Seq) + self.error( + loc, + """in protocol `%(P)s', the sequence of events + parent: +--`send %(M1)s'-->( state `%(T1)s' ) + / + ( state `%(S)s' ) + \\ + child: +--`send %(M2)s'-->( state `%(T2)s' ) +lead to parent/child states in which parent/child state can become more than one step out of sync (though this divergence might not lead to error conditions)"""% { + 'P': self.p.name, 'S': S, 'M1': M1, 'M2': M2, 'T1': T1, 'T2': T2 + }) |