# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
import re
from copy
import deepcopy
from collections
import OrderedDict
import itertools
import ipdl.ast
import ipdl.builtin
from ipdl.cxx.ast
import *
from ipdl.cxx.code
import *
from ipdl.type
import ActorType, UnionType, TypeVisitor, builtinHeaderIncludes
from ipdl.util
import hash_str
# -----------------------------------------------------------------------------
# "Public" interface to lowering
##
class LowerToCxx:
    def lower(self, tu, segmentcapacitydict):
        """Lower the translation unit |tu| to C++ files.

        Returns ``(headers, cpps)``, two lists of ``File`` objects. The
        common ``<tu.name>.h`` / ``<tu.name>.cpp`` come first. When |tu|
        declares a protocol, they are followed by the Parent and then the
        Child actor header/source.
        """
        # annotate the AST with IPDL/C++ IR-type stuff used later
        tu.accept(_DecorateWithCxxStuff())

        # Any modifications to the filename scheme here need corresponding
        # modifications in the ipdl.py driver script.
        commonHeader = File(tu.name + ".h")
        commonCpp = File(tu.name + ".cpp")
        _GenerateProtocolCode().lower(tu, commonHeader, commonCpp, segmentcapacitydict)
        headers, cpps = [commonHeader], [commonCpp]

        if not tu.protocol:
            return headers, cpps

        pname = tu.protocol.name
        sideGenerators = (
            ("Parent", _GenerateProtocolParentCode),
            ("Child", _GenerateProtocolChildCode),
        )
        for sideSuffix, generatorClass in sideGenerators:
            actorBase = pname + sideSuffix
            actorHeader = File(actorBase + ".h")
            actorCpp = File(actorBase + ".cpp")
            generatorClass().lower(tu, actorBase, actorHeader, actorCpp)
            headers.append(actorHeader)
            cpps.append(actorCpp)
        return headers, cpps
# -----------------------------------------------------------------------------
# Helper code
##
def hashfunc(value):
    """Return ``hash_str(value)`` reduced into the unsigned 32-bit range.

    The result always lies in ``[0, 2**32)``.
    """
    # Python's % takes the sign of its (positive) right operand, so the
    # result is already non-negative; no separate sign fix-up is needed.
    return hash_str(value) % 2**32
# Literal actor ids, named by role: ZERO for the null actor, ONE for a
# freed actor.
_NULL_ACTOR_ID = ExprLiteral.ZERO
_FREED_ACTOR_ID = ExprLiteral.ONE
# Banner text wrapped in a Whitespace node. It states that the output was
# generated by ipdlc.
_DISCLAIMER = Whitespace(
    """//
// Automatically generated by ipdlc.
// Edit at your own risk
//
"""
)
class _struct:
    """Empty class with no members of its own; usable as an attribute bag."""
def _namespacedHeaderName(name, namespaces):
    """Prefix |name| with the '/'-joined ``.name`` of each namespace.

    If the joined prefix is empty, |name| is returned unchanged.
    """
    prefix = "/".join(ns.name for ns in namespaces)
    return prefix + "/" + name if prefix else name
def _ipdlhHeaderName(tu):
    """Namespaced header path for a translation unit whose filetype is "header"."""
    assert tu.filetype == "header"
    unitName, unitNamespaces = tu.name, tu.namespaces
    return _namespacedHeaderName(unitName, unitNamespaces)
def _protocolHeaderName(p, side=""):
    """Namespaced header path for protocol |p|.

    When |side| is non-empty, it is title-cased and appended to the
    protocol name first (e.g. 'parent' -> 'PFooParent').
    """
    suffix = side.title() if side else side
    return _namespacedHeaderName(p.name + suffix, p.namespaces)
def _includeGuardMacroName(headerfile):
    """Include-guard macro for |headerfile|: its name with '.' and '/' mapped to '_'."""
    filename = headerfile.name
    return re.sub(r"[./]", "_", filename)
def _includeGuardStart(headerfile):
    """Opening include-guard directives: ``#ifndef GUARD`` then ``#define GUARD``."""
    guard = _includeGuardMacroName(headerfile)
    return [CppDirective(directive, guard) for directive in ("ifndef", "define")]
def _includeGuardEnd(headerfile):
    """Closing include-guard directive, annotated with the guard's name."""
    macro = _includeGuardMacroName(headerfile)
    closing = CppDirective("endif", "// ifndef " + macro)
    return [closing]
def _messageStartName(ptype):
    """Return ``<protocol name>MsgStart`` for protocol type |ptype|."""
    protocolName = ptype.name()
    return protocolName + "MsgStart"
def _protocolId(ptype):
    """C++ variable expression naming |ptype|'s ``...MsgStart`` identifier."""
    startName = _messageStartName(ptype)
    return ExprVar(startName)
def _protocolIdType():
    """C++ type used for protocol ids (32-bit signed int)."""
    idType = Type.INT32
    return idType
def _actorName(pname, side):
    """Join protocol name |pname| with |side| ('Parent'/'Child' or a lowercase form).

    A |side| whose first letter is lowercase is title-cased first.
    """
    suffix = side if side[0].isupper() else side.title()
    return pname + suffix
def _actorIdType():
    """C++ type used for actor ids (32-bit signed int)."""
    idType = Type.INT32
    return idType
def _actorTypeTagType():
    """C++ type used for actor type tags (32-bit signed int)."""
    tagType = Type.INT32
    return tagType
def _actorId(actor=None):
    """Expression calling ``Id()``: on |actor| via '->' if given, else unqualified."""
    if actor is None:
        return ExprCall(ExprVar("Id"))
    return ExprCall(ExprSelect(actor, "->", "Id"))
def _actorHId(actorhandle):
    """Expression selecting the ``mId`` member of |actorhandle| via '.'."""
    member = "mId"
    return ExprSelect(actorhandle, ".", member)
def _deleteId():
    """Variable expression for the ``__delete__`` message id."""
    msgIdName = "Msg___delete____ID"
    return ExprVar(msgIdName)
def _deleteReplyId():
    """Variable expression for the ``__delete__`` reply id."""
    replyIdName = "Reply___delete____ID"
    return ExprVar(replyIdName)
def _lookupListener(idexpr):
    """Expression calling ``Lookup(idexpr)``."""
    lookupFn = ExprVar("Lookup")
    return ExprCall(lookupFn, args=[idexpr])
def _makeForwardDeclForQClass(clsname, quals, cls=True, struct=False):
    """Forward-declare |clsname| inside nested namespaces named by |quals|.

    |quals| is ordered outermost first. With no qualifiers, the bare
    ForwardDecl is returned; otherwise the outermost Namespace is returned,
    with each level containing exactly the next one.
    """
    node = ForwardDecl(clsname, cls=cls, struct=struct)
    # Wrap from the innermost qualifier outwards.
    for qual in reversed(quals):
        wrapper = Namespace(qual)
        wrapper.addstmt(node)
        node = wrapper
    return node
def _makeForwardDeclForActor(ptype, side):
    """Forward declaration of the |side| actor class for protocol type |ptype|."""
    qname = ptype.qname
    actorClass = _actorName(qname.baseid, side)
    return _makeForwardDeclForQClass(actorClass, qname.quals)
def _makeForwardDecl(type):
    """Forward declaration of |type|'s class within its qualifying namespaces."""
    namespaces = type.qname.quals
    return _makeForwardDeclForQClass(type.name(), namespaces)
def _putInNamespaces(cxxthing, namespaces):
    """Nest |cxxthing| inside C++ Namespaces built from |namespaces|.

    |namespaces| is in order [ outer, ..., inner ], and each element
    contributes its ``.name``. With no namespaces, |cxxthing| is returned
    as-is; otherwise the outermost Namespace is returned.
    """
    node = cxxthing
    for ns in reversed(namespaces):
        wrapper = Namespace(ns.name)
        wrapper.addstmt(node)
        node = wrapper
    return node
def _sendPrefix(msgtype):
    """Prefix of the name of the C++ method that sends |msgtype|.

    The prefix is currently "Send" regardless of |msgtype|.
    """
    prefix = "Send"
    return prefix
def _recvPrefix(msgtype):
    """Prefix of the name of the C++ method that handles |msgtype|.

    The prefix is currently "Recv" regardless of |msgtype|.
    """
    prefix = "Recv"
    return prefix
def _flatTypeName(ipdltype):
    """Return a 'flattened' IPDL type name that can be used as an identifier.

    E.g., |Foo[]| --> |ArrayOfFoo|.
    """
    # NB: this logic depends heavily on what IPDL types are allowed to
    # be constructed; e.g., Foo[][] is disallowed. needs to be kept in
    # sync with grammar.
    prefix = ""
    current = ipdltype
    while current.isIPDL():
        if current.isArray():
            prefix += "ArrayOf"
        elif current.isMaybe():
            prefix += "Maybe"
        elif not (current.isNotNull() or current.isUniquePtr()):
            break
        # NotNull and UniquePtr just assume the underlying variant name to
        # avoid noise, as eg a NotNull<T> and T should never exist in the
        # same union.
        current = current.basetype
    return prefix + current.name()
def _hasVisibleActor(ipdltype):
    """Return true iff a C++ decl of |ipdltype| would have an Actor* type.

    For example, |Actor[]| turns into |Array<ActorParent*>|, so this
    returns true for |Actor[]|. IPDL wrapper types are unwrapped through
    their ``basetype`` until an actor or a non-wrapper type is reached.
    """
    current = ipdltype
    while current.isIPDL():
        if current.isActor():
            return True
        if not current.hasBaseType():
            return False
        current = current.basetype
    return False
def _abortIfFalse(cond, msg):
    """Statement ``MOZ_RELEASE_ASSERT(cond, "msg");``."""
    assertion = ExprCall(ExprVar("MOZ_RELEASE_ASSERT"), [cond, ExprLiteral.String(msg)])
    return StmtExpr(assertion)
def _refptr(T):
    """The C++ type ``RefPtr<T>``."""
    template = "RefPtr"
    return Type(template, T=T)
def _alreadyaddrefed(T):
    """The C++ type ``already_AddRefed<T>``."""
    template = "already_AddRefed"
    return Type(template, T=T)
def _tuple(types, const=False, ref=False):
    """The C++ type ``std::tuple<types...>`` with optional const/ref qualifiers."""
    qualifiers = {"const": const, "ref": ref}
    return Type("std::tuple", T=types, **qualifiers)
def _promise(resolvetype, rejecttype, tail, resolver=False):
    """The C++ type ``MozPromise<resolvetype, rejecttype, tail>``.

    When |resolver| is true, the type refers to the nested ``Private``
    type of that MozPromise instantiation.
    """
    inner = None
    if resolver:
        inner = Type("Private")
    return Type("MozPromise", T=[resolvetype, rejecttype, tail], inner=inner)
def _makePromise(returns, side, resolver=False):
    """Exclusive MozPromise type resolving with the message's |returns|.

    Several return values resolve as a ``std::tuple`` of their bare types.
    The promise rejects with ResponseRejectReason.
    """
    resolvetype = (
        _tuple([d.bareType(side) for d in returns])
        if len(returns) > 1
        else returns[0].bareType(side)
    )
    # MozPromise is purposefully made to be exclusive only. Really, we mean it.
    exclusive = ExprLiteral.TRUE
    return _promise(
        resolvetype, _ResponseRejectReason.Type(), exclusive, resolver=resolver
    )
def _resolveType(returns, side):
    """Send-direction in-type of the resolve value for |returns|.

    A single return value uses its own in-type; several are wrapped in a
    ``std::tuple``.
    """
    if len(returns) <= 1:
        return returns[0].inType(side, "send")
    return _tuple([d.inType(side, "send") for d in returns])
def _makeResolver(returns, side):
    """Function type taking the resolve value for |returns| as its one (unnamed) parameter."""
    param = Decl(_resolveType(returns, side), "")
    return TypeFunction([param])
def _cxxArrayType(basetype, const=False, ref=False):
    """The C++ type ``nsTArray<basetype>``, marked as lacking an implicit copy ctor."""
    qualifiers = {"const": const, "ref": ref}
    return Type("nsTArray", T=basetype, hasimplicitcopyctor=False, **qualifiers)
def _cxxSpanType(basetype, const=False, ref=False):
    """The C++ type ``mozilla::Span<basetype const>``.

    |basetype| is deep-copied before being marked right-const, so the
    caller's Type object is left untouched.
    """
    element = deepcopy(basetype)
    element.rightconst = True
    qualifiers = {"const": const, "ref": ref}
    return Type("mozilla::Span", T=element, hasimplicitcopyctor=True, **qualifiers)
def _cxxMaybeType(basetype, const=False, ref=False):
    """The C++ type ``mozilla::Maybe<basetype>``; copyable iff |basetype| is."""
    copyable = basetype.hasimplicitcopyctor
    return Type(
        "mozilla::Maybe", T=basetype, const=const, ref=ref, hasimplicitcopyctor=copyable
    )
def _cxxReadResultType(basetype, const=False, ref=False):
    """The C++ type ``IPC::ReadResult<basetype>``; copyable iff |basetype| is."""
    copyable = basetype.hasimplicitcopyctor
    return Type(
        "IPC::ReadResult", T=basetype, const=const, ref=ref, hasimplicitcopyctor=copyable
    )
def _cxxNotNullType(basetype, const=False, ref=False):
    """The C++ type ``mozilla::NotNull<basetype>``; copyable iff |basetype| is."""
    copyable = basetype.hasimplicitcopyctor
    return Type(
        "mozilla::NotNull", T=basetype, const=const, ref=ref, hasimplicitcopyctor=copyable
    )
def _cxxManagedContainerType(basetype, const=False, ref=False):
    """The C++ type ``ManagedContainer<basetype>``, marked as lacking an implicit copy ctor."""
    qualifiers = {"const": const, "ref": ref}
    return Type("ManagedContainer", T=basetype, hasimplicitcopyctor=False, **qualifiers)
def _cxxLifecycleProxyType(ptr=False):
    """The C++ type ``mozilla::ipc::ActorLifecycleProxy`` (optionally a pointer)."""
    proxyName = "mozilla::ipc::ActorLifecycleProxy"
    return Type(proxyName, ptr=ptr)
def _cxxSide(side):
    """Map an IPDL side name to its C++ side expression.

    "child" maps to ``mozilla::ipc::ChildSide`` and "parent" maps to
    ``mozilla::ipc::ParentSide``.

    Raises:
        AssertionError: for any other |side|.
    """
    if side == "child":
        return ExprVar("mozilla::ipc::ChildSide")
    if side == "parent":
        return ExprVar("mozilla::ipc::ParentSide")
    # Raise explicitly: a bare `assert 0` is stripped under `python -O`,
    # which would silently hand callers None.
    raise AssertionError("unknown side %r" % (side,))
def _otherSide(side):
    """Return the opposite side name: "child" <-> "parent".

    Raises:
        AssertionError: for any other |side|.
    """
    if side == "child":
        return "parent"
    if side == "parent":
        return "child"
    # Raise explicitly: a bare `assert 0` is stripped under `python -O`,
    # which would silently hand callers None.
    raise AssertionError("unknown side %r" % (side,))
# XXX we need to remove these and install proper error handling
def _printErrorMessage(msg):
    """Statement ``NS_ERROR(msg);``. A Python str |msg| becomes a string literal."""
    text = ExprLiteral.String(msg) if isinstance(msg, str) else msg
    return StmtExpr(ExprCall(ExprVar("NS_ERROR"), args=[text]))
def _protocolErrorBreakpoint(msg):
    """Statement calling ``mozilla::ipc::ProtocolErrorBreakpoint(msg)``.

    A Python str |msg| becomes a string literal.
    """
    text = ExprLiteral.String(msg) if isinstance(msg, str) else msg
    breakpoint = ExprVar("mozilla::ipc::ProtocolErrorBreakpoint")
    return StmtExpr(ExprCall(breakpoint, args=[text]))
def _printWarningMessage(msg):
    """Statement ``NS_WARNING(msg);``. A Python str |msg| becomes a string literal."""
    text = ExprLiteral.String(msg) if isinstance(msg, str) else msg
    return StmtExpr(ExprCall(ExprVar("NS_WARNING"), args=[text]))
def _fatalError(msg):
    """Statement ``FatalError("msg");``."""
    call = ExprCall(ExprVar("FatalError"), args=[ExprLiteral.String(msg)])
    return StmtExpr(call)
def _logicError(msg):
    """Statement ``mozilla::ipc::LogicError("msg");``."""
    call = ExprCall(ExprVar("mozilla::ipc::LogicError"), args=[ExprLiteral.String(msg)])
    return StmtExpr(call)
def _sentinelReadError(classname):
    """Statement ``mozilla::ipc::SentinelReadError("classname");``."""
    reporter = ExprVar("mozilla::ipc::SentinelReadError")
    call = ExprCall(reporter, args=[ExprLiteral.String(classname)])
    return StmtExpr(call)
identifierRegExp = re.compile("[a-zA-Z_][a-zA-Z0-9_]*")


def _validCxxIdentifier(name):
    """True iff all of |name| is a letter/underscore followed by letters, digits or underscores.

    C++ keywords are not rejected.
    """
    match = identifierRegExp.fullmatch(name)
    return match is not None
# Results that IPDL-generated code returns back to *Channel code.
# Users never see these
class _Result:
    """C++ ``Result`` type and the variable expressions for its values."""

    @staticmethod
    def Type():
        # Inside this function body, `Type` resolves to the module-level
        # Type, not this staticmethod.
        return Type("Result")

    Processed = ExprVar("MsgProcessed")
    Dropped = ExprVar("MsgDropped")
    NotKnown = ExprVar("MsgNotKnown")
    NotAllowed = ExprVar("MsgNotAllowed")
    PayloadError = ExprVar("MsgPayloadError")
    ProcessingError = ExprVar("MsgProcessingError")
    ValuError = ExprVar("MsgValueError")  # [sic] attribute name is misspelled; kept for callers
# These |errfn*| functions generate the code to run on an error, such as
# "bad actor ID". Each is given a Python string describing the error.


# used in user-facing Send*() methods
def errfnSend(msg, errcode=ExprLiteral.FALSE):
    """Statements that raise a FatalError(msg) and then return |errcode|."""
    report = _fatalError(msg)
    bail = StmtReturn(errcode)
    return [report, bail]
def errfnSendCtor(msg):
    """errfnSend variant for constructor senders: returns a null literal."""
    nullResult = ExprLiteral.NULL
    return errfnSend(msg, errcode=nullResult)
# TODO should this error handling be strengthened for dtors?
def errfnSendDtor(msg):
    """Statements that NS_ERROR(msg) and then return false."""
    report = _printErrorMessage(msg)
    return [report, StmtReturn.FALSE]
# used in |OnMessage*()| handlers that hand in-messages off to Recv*()
# interface methods
def errfnRecv(msg, errcode=_Result.ValuError):
    """Statements that raise a FatalError(msg) and then return |errcode|."""
    report = _fatalError(msg)
    bail = StmtReturn(errcode)
    return [report, bail]
def errfnSentinel(rvalue=ExprLiteral.FALSE):
    """Build an error callback for sentinel read failures.

    The returned function takes |msg| and yields statements that report
    SentinelReadError(msg) and then return |rvalue|.
    """

    def onSentinelError(msg):
        return [_sentinelReadError(msg), StmtReturn(rvalue)]

    return onSentinelError
def errfnUnreachable(msg):
    """A single-statement list reporting a LogicError(msg)."""
    report = _logicError(msg)
    return [report]
def readResultError():
    """The C++ expression ``{}``, emitted verbatim as code."""
    emptyBraces = "{}"
    return ExprCode(emptyBraces)
class _DestroyReason:
    """C++ ``ActorDestroyReason`` type and variable expressions for its values."""

    @staticmethod
    def Type():
        # Resolves to the module-level Type inside this function body.
        return Type("ActorDestroyReason")

    Deletion = ExprVar("Deletion")
    AncestorDeletion = ExprVar("AncestorDeletion")
    NormalShutdown = ExprVar("NormalShutdown")
    AbnormalShutdown = ExprVar("AbnormalShutdown")
    FailedConstructor = ExprVar("FailedConstructor")
    ManagedEndpointDropped = ExprVar("ManagedEndpointDropped")
class _ResponseRejectReason:
    """C++ ``ResponseRejectReason`` type and fully-scoped expressions for its values."""

    @staticmethod
    def Type():
        # Resolves to the module-level Type inside this function body.
        return Type("ResponseRejectReason")

    SendError = ExprVar("ResponseRejectReason::SendError")
    ChannelClosed = ExprVar("ResponseRejectReason::ChannelClosed")
    HandlerRejected = ExprVar("ResponseRejectReason::HandlerRejected")
    ActorDestroyed = ExprVar("ResponseRejectReason::ActorDestroyed")
# -----------------------------------------------------------------------------
# Intermediate representation (IR) nodes used during lowering
class _ConvertToCxxType(TypeVisitor):
    """TypeVisitor that maps an IPDL type to its bare C++ ``Type``.

    |side| is None or a side name that is passed through to _actorName.
    |fq| selects fully-qualified names (``fullname()``) over ``name()``.
    """

    def __init__(self, side, fq):
        self.side = side
        self.fq = fq

    def typename(self, thing):
        return thing.fullname() if self.fq else thing.name()

    def _namedType(self, t):
        # Types that lower to a plain C++ type with the same (possibly
        # fully-qualified) name.
        return Type(self.typename(t))

    def _unreachable(self, t):
        # These IPDL types never appear where a C++ value type is needed.
        assert 0

    def visitImportedCxxType(self, t):
        base = Type(self.typename(t))
        return _refptr(base) if t.isRefcounted() else base

    visitBuiltinCType = _namedType

    def visitActorType(self, a):
        if self.side is not None:
            return Type(_actorName(self.typename(a.protocol), self.side), ptr=True)
        # Without a side, an actor becomes a SideVariant over both sides'
        # actor types.
        variants = [_cxxBareType(a, sideName, self.fq) for sideName in ("parent", "child")]
        return Type("::mozilla::ipc::SideVariant", T=variants)

    visitStructType = _namedType
    visitUnionType = _namedType

    def visitArrayType(self, a):
        return _cxxArrayType(a.basetype.accept(self))

    def visitMaybeType(self, m):
        return _cxxMaybeType(m.basetype.accept(self))

    visitShmemType = _namedType
    visitByteBufType = _namedType
    visitFDType = _namedType
    visitEndpointType = _namedType
    visitManagedEndpointType = _namedType
    visitUniquePtrType = _namedType

    def visitNotNullType(self, n):
        return _cxxNotNullType(n.basetype.accept(self))

    visitProtocolType = _unreachable
    visitMessageType = _unreachable
    visitVoidType = _unreachable
def _cxxBareType(ipdltype, side, fq=False):
    """Undecorated C++ type for |ipdltype|, as produced by _ConvertToCxxType."""
    converter = _ConvertToCxxType(side, fq)
    return ipdltype.accept(converter)
def _cxxRefType(ipdltype, side):
    """Bare C++ type for |ipdltype| with its ``ref`` flag set."""
    reftype = _cxxBareType(ipdltype, side)
    reftype.ref = True
    return reftype
def _cxxConstRefType(ipdltype, side):
    """C++ type used for read-only access to a value of |ipdltype|.

    Mostly ``const T&``. The exceptions:
    - actors keep their bare type;
    - Shmem is a non-const reference;
    - NotNull wraps a raw-pointer inner type;
    - wrappers mirror the inner type's constness;
    - refcounted C++ types become ``T*``.
    """
    t = _cxxBareType(ipdltype, side)
    if ipdltype.isIPDL():
        if ipdltype.isActor():
            return t
        if ipdltype.isShmem():
            t.ref = True
            return t
        if ipdltype.isNotNull():
            # If the inner type chooses to use a raw pointer, wrap that instead.
            inner = _cxxConstRefType(ipdltype.basetype, side)
            return _cxxNotNullType(inner) if inner.ptr else t
        if ipdltype.hasBaseType():
            # Keep same constness as inner type.
            inner = _cxxConstRefType(ipdltype.basetype, side)
            t.const = inner.const or not inner.ref
            t.ref = True
            return t
    if ipdltype.isCxx():
        if ipdltype.isSendMoveOnly() or ipdltype.isDataMoveOnly():
            t.const = True
            t.ref = True
            return t
        if ipdltype.isRefcounted():
            # Use T* instead of const RefPtr<T>&
            pointee = t.T
            pointee.ptr = True
            return pointee
    t.const = True
    t.ref = True
    return t
def _cxxTypeNeedsMoveForSend(ipdltype, context="root", visited=None):
    """Returns `True` if serializing ipdltype requires a mutable reference.

    This happens, e.g., because the underlying resource represented by the
    value is being transferred to another process. It is occasionally
    distinct from whether the C++ type exposes a copy constructor, such as
    for types which are not cheaply copiable but are not mutated when
    serialized.

    |context| is "root", "wrapper" (inside an array/maybe/... base type) or
    "compound" (inside a struct/union). |visited| accumulates the types
    already seen, so recursive structs/unions terminate.
    """
    visited = set() if visited is None else visited
    visited.add(ipdltype)

    if ipdltype.isCxx():
        return ipdltype.isSendMoveOnly()
    if not ipdltype.isIPDL():
        return False

    if ipdltype.hasBaseType():
        return _cxxTypeNeedsMoveForSend(ipdltype.basetype, "wrapper", visited)

    if ipdltype.isStruct() or ipdltype.isUnion():
        # `visited` grows during the walk, so membership is re-checked per
        # component, exactly as a lazy generator filter would.
        for component in ipdltype.itercomponents():
            if component in visited:
                continue
            if _cxxTypeNeedsMoveForSend(component, "compound", visited):
                return True
        return False

    # For historical reasons, shmem is `const_cast` to a mutable reference
    # when being stored in a struct or union (see
    # `_StructField.constRefExpr` and `_UnionMember.getConstValue`), meaning
    # that they do not cause the containing struct to require move for
    # sending.
    if ipdltype.isShmem():
        return context != "compound"

    return ipdltype.isByteBuf() or ipdltype.isEndpoint() or ipdltype.isManagedEndpoint()
def _cxxTypeNeedsMoveForData(ipdltype, context="root", visited=None):
    """Returns `True` if the bare C++ type for ipdltype is not copy-constructible.

    That is, it does not satisfy std::is_copy_constructible_v<T>. All C++
    types supported by IPDL must support std::is_move_constructible_v<T>,
    so non-movable types must be passed behind a `UniquePtr`.

    |context| and |visited| have the same meaning as in
    _cxxTypeNeedsMoveForSend.
    """
    visited = set() if visited is None else visited
    visited.add(ipdltype)

    if ipdltype.isCxx():
        return ipdltype.isDataMoveOnly()
    if not ipdltype.isIPDL():
        return False

    if ipdltype.isUniquePtr():
        return True
    # When nested within a maybe or array, arrays are no longer copyable.
    if context == "wrapper" and ipdltype.isArray():
        return True
    if ipdltype.hasBaseType():
        return _cxxTypeNeedsMoveForData(ipdltype.basetype, "wrapper", visited)

    if ipdltype.isStruct() or ipdltype.isUnion():
        # `visited` grows during the walk; re-check membership per component.
        for component in ipdltype.itercomponents():
            if component in visited:
                continue
            if _cxxTypeNeedsMoveForData(component, "compound", visited):
                return True
        return False

    return ipdltype.isByteBuf() or ipdltype.isEndpoint() or ipdltype.isManagedEndpoint()
def _cxxTypeCanMove(ipdltype):
    """Everything except IPDL actor types may be passed by rvalue reference."""
    isActor = ipdltype.isIPDL() and ipdltype.isActor()
    return not isActor
def _cxxForceMoveRefType(ipdltype, side):
    """Bare C++ type for |ipdltype| as an rvalue reference (``T&&``)."""
    assert _cxxTypeCanMove(ipdltype)
    rvalueType = _cxxBareType(ipdltype, side)
    rvalueType.rvalref = True
    return rvalueType
def _cxxPtrToType(ipdltype, side):
    """C++ pointer-to-|ipdltype|: ``T*``, or ``Actor**`` for sided actors."""
    t = _cxxBareType(ipdltype, side)
    if ipdltype.isIPDL() and ipdltype.isActor() and side is not None:
        # A sided actor already lowers to Actor*; point at that pointer.
        t.ptr, t.ptrptr = False, True
    else:
        t.ptr = True
    return t
def _cxxConstPtrToType(ipdltype, side):
    """Const pointer-to-|ipdltype|: ``const T*``, or a ptrconstptr form for sided actors."""
    t = _cxxBareType(ipdltype, side)
    if ipdltype.isIPDL() and ipdltype.isActor() and side is not None:
        t.ptr, t.ptrconstptr = False, True
    else:
        t.const, t.ptr = True, True
    return t
def _cxxInType(ipdltype, side, direction):
    """C++ parameter type for passing a value of |ipdltype| into a method.

    The cases are checked in this order:
    - actors keep their bare type;
    - NotNull wraps a raw-pointer inner type;
    - send-move-only types become ``T&&``;
    - refcounted C++ types become ``T*``;
    - nsCString/nsString bind as nsACString/nsAString;
    - IPDL arrays on the "send" direction become ``Span<T const>``;
    - everything else is ``const T&``.
    """
    t = _cxxBareType(ipdltype, side)
    isIPDL = ipdltype.isIPDL()
    if isIPDL and ipdltype.isActor():
        return t
    if isIPDL and ipdltype.isNotNull():
        # If the inner type chooses to use a raw pointer, wrap that instead.
        inner = _cxxInType(ipdltype.basetype, side, direction)
        return _cxxNotNullType(inner) if inner.ptr else t
    if _cxxTypeNeedsMoveForSend(ipdltype):
        t.rvalref = True
        return t
    if ipdltype.isCxx():
        if ipdltype.isRefcounted():
            # Use T* instead of const RefPtr<T>&
            pointee = t.T
            pointee.ptr = True
            return pointee
        abstractString = {"nsCString": "nsACString", "nsString": "nsAString"}.get(
            ipdltype.name()
        )
        if abstractString is not None:
            t = Type(abstractString)
    # Use Span<T const> rather than nsTArray<T> for array types which aren't
    # `_cxxTypeNeedsMoveForSend`. This is only done for the "send" side, and not
    # for recv signatures.
    if direction == "send" and isIPDL and ipdltype.isArray():
        return _cxxSpanType(_cxxBareType(ipdltype.basetype, side))
    t.const = True
    t.ref = True
    return t
def _allocMethod(ptype, side):
    """Name of the ``Alloc<Protocol><Side>`` method for |ptype| on |side|."""
    parts = ("Alloc", ptype.name(), side.title())
    return "".join(parts)
def _deallocMethod(ptype, side):
    """Name of the ``Dealloc<Protocol><Side>`` method for |ptype| on |side|."""
    parts = ("Dealloc", ptype.name(), side.title())
    return "".join(parts)
##
# A _HybridDecl straddles IPDL and C++ decls. It knows which C++
# types correspond to which IPDL types, and it also knows how to
# serialize and deserialize "special" IPDL C++ types.
##
class _HybridDecl:
    """A hybrid decl stores both an IPDL type and all the C++ type
    info needed by later passes, along with a basic name for the decl.

    |attributes| is stored as-is on the decl; elsewhere in this module
    membership tests such as ``"NoTaint" in d.attributes`` are made on it.
    When omitted (or None), every decl gets its own fresh empty dict.
    """

    def __init__(self, ipdltype, name, attributes=None):
        self.ipdltype = ipdltype
        self.name = name
        # A shared `{}` default would leak a mutation of one decl's
        # attributes into every other decl created without explicit ones.
        self.attributes = {} if attributes is None else attributes

    def var(self):
        """C++ variable expression naming this decl."""
        return ExprVar(self.name)

    def bareType(self, side, fq=False):
        """Return this decl's C++ type with no ref/const/ptr decoration.

        |fq| requests fully-qualified type names.
        """
        return _cxxBareType(self.ipdltype, side, fq=fq)

    def refType(self, side):
        """Return this decl's C++ type as a 'reference' type, which is not
        necessarily a C++ reference.
        """
        return _cxxRefType(self.ipdltype, side)

    def constRefType(self, side):
        """Return this decl's C++ type as a const, 'reference' type."""
        return _cxxConstRefType(self.ipdltype, side)

    def ptrToType(self, side):
        """Return a C++ pointer to this decl's type."""
        return _cxxPtrToType(self.ipdltype, side)

    def constPtrToType(self, side):
        """Return a const C++ pointer to this decl's type."""
        return _cxxConstPtrToType(self.ipdltype, side)

    def inType(self, side, direction):
        """Return this decl's C++ Type with sending inparam semantics."""
        return _cxxInType(self.ipdltype, side, direction)

    def outType(self, side):
        """Return this decl's C++ Type with outparam semantics.

        This is ``T*`` in general; for actors it is a pointer-to-pointer
        (``ptrptr``).
        """
        t = self.bareType(side)
        if self.ipdltype.isIPDL() and self.ipdltype.isActor():
            t.ptr = False
            t.ptrptr = True
            return t
        t.ptr = True
        return t

    def forceMoveType(self, side):
        """Return this decl's C++ Type with forced move semantics."""
        assert _cxxTypeCanMove(self.ipdltype)
        return _cxxForceMoveRefType(self.ipdltype, side)
# --------------------------------------------------
class HasFQName:
    """Mixin exposing the fully-qualified name of ``self.decl.type``."""

    def fqClassName(self):
        declType = self.decl.type
        return declType.fullname()
class _CompoundTypeComponent(_HybridDecl):
    """A member of a compound (struct/union) type.

    Overrides the _HybridDecl type helpers so that |side| is optional and
    defaults to None.
    """

    def bareType(self, side=None, fq=False):
        return super().bareType(side, fq=fq)

    def refType(self, side=None):
        return super().refType(side)

    def constRefType(self, side=None):
        return super().constRefType(side)

    def ptrToType(self, side=None):
        return super().ptrToType(side)

    def constPtrToType(self, side=None):
        return super().constPtrToType(side)

    def forceMoveType(self, side=None):
        return super().forceMoveType(side)
class StructDecl(ipdl.ast.StructDecl, HasFQName):
    """ipdl.ast.StructDecl extended with field-ordering helpers."""

    def fields_ipdl_order(self):
        """Yield fields in their IPDL declaration order."""
        yield from self.fields

    def fields_member_order(self):
        """Yield fields in C++ member order, as given by ``packed_field_order``."""
        assert len(self.packed_field_order) == len(self.fields)
        for index in self.packed_field_order:
            yield self.fields[index]

    @staticmethod
    def upgrade(structDecl):
        """Re-class |structDecl| (an ipdl.ast.StructDecl) in place."""
        assert isinstance(structDecl, ipdl.ast.StructDecl)
        structDecl.__class__ = StructDecl
class _StructField(_CompoundTypeComponent):
    """A named field of an IPDL struct."""

    def __init__(self, ipdltype, name, sd):
        # |sd| is accepted for interface compatibility but not used here.
        self.basename = name
        super().__init__(ipdltype, name)

    def getMethod(self, thisexpr=None, sel="."):
        """Accessor expression for this field.

        This is a bare variable, or one selected from |thisexpr| with |sel|.
        """
        accessor = self.var()
        if thisexpr is None:
            return accessor
        return ExprSelect(thisexpr, sel, accessor.name)

    def refExpr(self, thisexpr=None):
        """Member-variable expression, selected from |thisexpr| via '.' when given."""
        member = self.memberVar()
        if thisexpr is None:
            return member
        return ExprSelect(thisexpr, ".", member.name)

    def constRefExpr(self, thisexpr=None):
        """Like refExpr(), but Shmem fields are cast to a mutable ``Shmem&``."""
        ref = self.refExpr(thisexpr)
        # sigh, gross hack
        if self.ipdltype.name() != "Shmem":
            return ref
        return ExprCast(ref, Type("Shmem", ref=True), const=True)

    def argVar(self):
        """Constructor-argument variable: ``_<name>``."""
        return ExprVar("_" + self.name)

    def memberVar(self):
        """Storage member variable: ``<name>_``."""
        return ExprVar(self.name + "_")
class UnionDecl(ipdl.ast.UnionDecl, HasFQName):
    """ipdl.ast.UnionDecl extended with C++ expression helpers."""

    def callType(self, var=None):
        """Expression calling ``type()``, on |var| via '.' when given."""
        accessor = ExprVar("type")
        target = accessor if var is None else ExprSelect(var, ".", accessor.name)
        return ExprCall(target)

    @staticmethod
    def upgrade(unionDecl):
        """Re-class |unionDecl| (an ipdl.ast.UnionDecl) in place."""
        assert isinstance(unionDecl, ipdl.ast.UnionDecl)
        unionDecl.__class__ = UnionDecl
class _UnionMember(_CompoundTypeComponent):
    """Not in the AFL sense, but rather a member (e.g. |int;|) of an
    IPDL union type.
    """

    def __init__(self, ipdltype, ud):
        flatname = _flatTypeName(ipdltype)
        assert _validCxxIdentifier(flatname)
        super().__init__(ipdltype, "mV" + flatname)
        self.flattypename = flatname
        # To create a finite object with a mutually recursive type, a union must
        # be present somewhere in the recursive loop. Because of that we only
        # need to care about introducing indirections inside unions.
        self.recursive = ud.decl.type.mutuallyRecursiveWith(ipdltype)

    def enum(self):
        """Discriminant enumerator name: ``T<flattypename>``."""
        return "T" + self.flattypename

    def enumvar(self):
        return ExprVar(self.enum())

    def internalType(self):
        """Stored type: a pointer for recursive members, else the bare type."""
        return self.ptrToType() if self.recursive else self.bareType()

    def unionType(self):
        """Type used for storage in generated C union decl."""
        # Same as internalType(): pointer when recursive, bare type otherwise.
        return self.internalType()

    def unionValue(self):
        return ExprVar(self.name)

    def typedef(self):
        return self.flattypename + "__tdef"

    def callGetConstPtr(self):
        """Expression calling this member's ``constptr_*`` accessor."""
        return ExprCall(ExprVar(self.getConstPtrName()))

    def callGetPtr(self):
        """Expression calling this member's ``ptr_*`` accessor."""
        return ExprCall(ExprVar(self.getPtrName()))

    def callCtor(self, expr=None):
        """Expression constructing this member's storage, optionally from |expr|.

        Arrays not passed via ExprMove are constructed from ``expr.Clone()``.
        Recursive members are heap-allocated and assigned through the
        pointer accessor; others use placement new at that pointer.
        """
        assert not isinstance(expr, list)
        if expr is None:
            args = None
        else:
            needsClone = (
                self.ipdltype.isIPDL()
                and self.ipdltype.isArray()
                and not isinstance(expr, ExprMove)
            )
            args = [ExprCall(ExprSelect(expr, ".", "Clone"), args=[])] if needsClone else [expr]
        if self.recursive:
            return ExprAssn(self.callGetPtr(), ExprNew(self.bareType(), args=args))
        return ExprNew(
            self.bareType(),
            args=args,
            newargs=[ExprVar("mozilla::KnownNotNull"), self.callGetPtr()],
        )

    def callDtor(self):
        """Expression destroying this member's storage."""
        if self.recursive:
            return ExprDelete(self.callGetPtr())
        destructor = "~" + self.typedef()
        return ExprCall(ExprSelect(self.callGetPtr(), "->", destructor))

    def getTypeName(self):
        return "get_" + self.flattypename

    def getConstTypeName(self):
        return "get_" + self.flattypename

    def getOtherTypeName(self):
        # NOTE(review): `otherflattypename` is not assigned in this class;
        # presumably set elsewhere before use — confirm.
        return "get_" + self.otherflattypename

    def getPtrName(self):
        return "ptr_" + self.flattypename

    def getConstPtrName(self):
        return "constptr_" + self.flattypename

    def ptrToSelfExpr(self):
        """|*ptrToSelfExpr()| has type |self.bareType()|"""
        value = self.unionValue()
        return value if self.recursive else ExprAddrOf(value)

    def constptrToSelfExpr(self):
        """Same expression as ptrToSelfExpr()."""
        return self.ptrToSelfExpr()

    def ptrToInternalType(self):
        """ptrToType(), additionally marked ``ref`` for recursive members."""
        ptrType = self.ptrToType()
        if self.recursive:
            ptrType.ref = True
        return ptrType

    def defaultValue(self, fq=False):
        """Default-value expression, or None to use the default constructor.

        None is returned for types without an implicit copy ctor; actors
        default to a null literal; others use ``T()``.
        """
        if not self.bareType().hasimplicitcopyctor:
            return None
        if self.ipdltype.isIPDL() and self.ipdltype.isActor():
            return ExprLiteral.NULL
        # XXX sneaky here, maybe need ExprCtor()?
        return ExprCall(self.bareType(fq=fq))

    def getConstValue(self):
        """Dereferenced const-pointer accessor; Shmem is cast to a mutable ``Shmem&``."""
        value = ExprDeref(self.callGetConstPtr())
        # sigh
        if self.ipdltype.name() != "Shmem":
            return value
        return ExprCast(value, Type("Shmem", ref=True), const=True)
# --------------------------------------------------
class MessageDecl(ipdl.ast.MessageDecl):
    """ipdl.ast.MessageDecl extended with the C++ names and parameter lists
    generated for the message."""

    def baseName(self):
        return self.name

    def _ctorSuffix(self):
        # Constructor messages append "Constructor" to their method and
        # promise names.
        return "Constructor" if self.decl.type.isCtor() else ""

    def recvMethod(self):
        return _recvPrefix(self.decl.type) + self.baseName() + self._ctorSuffix()

    def sendMethod(self):
        return _sendPrefix(self.decl.type) + self.baseName() + self._ctorSuffix()

    def hasReply(self):
        mtype = self.decl.type
        return mtype.hasReply() or mtype.isCtor() or mtype.isDtor()

    def hasAsyncReturns(self):
        # NB: for async messages this evaluates to self.returns itself
        # (truthy iff non-empty), not a strict bool.
        return self.decl.type.isAsync() and self.returns

    def msgCtorFunc(self):
        return "Msg_%s" % (self.decl.progname)

    def prettyMsgName(self, pfx=""):
        return pfx + self.msgCtorFunc()

    def pqMsgCtorFunc(self):
        return "%s::%s" % (self.namespace, self.msgCtorFunc())

    def msgId(self):
        return self.msgCtorFunc() + "__ID"

    def pqMsgId(self):
        return "%s::%s" % (self.namespace, self.msgId())

    def replyCtorFunc(self):
        return "Reply_%s" % (self.decl.progname)

    def pqReplyCtorFunc(self):
        return "%s::%s" % (self.namespace, self.replyCtorFunc())

    def replyId(self):
        return self.replyCtorFunc() + "__ID"

    def pqReplyId(self):
        return "%s::%s" % (self.namespace, self.replyId())

    def prettyReplyName(self, pfx=""):
        return pfx + self.replyCtorFunc()

    def promiseName(self):
        return self.baseName() + self._ctorSuffix() + "Promise"

    def resolverName(self):
        return self.baseName() + "Resolver"

    def actorDecl(self):
        return self.params[0]

    def makeCxxParams(
        self, paramsems="in", returnsems="out", side=None, implicit=True, direction=None
    ):
        """Return a list of C++ decls per the spec'd configuration.

        |paramsems| ('in', 'move' or None) selects how params are declared.
        |returnsems| ('in', 'out', 'promise', 'callback', 'resolver' or
        None) selects how return values are declared. With
        implicit=False, the implicit actor param is dropped.
        """

        def declFor(d, sems):
            taintedRecv = (
                self.decl.type.tainted
                and "NoTaint" not in d.attributes
                and direction == "recv"
            )
            if taintedRecv:
                # Tainted types are passed by-value, allowing the receiver to move them if desired.
                assert sems != "out"
                return Decl(Type("Tainted", T=d.bareType(side)), d.name)
            if sems == "in":
                t = d.inType(side, direction)
                # If this is the `recv` side, and we're not using "move"
                # semantics, that means we're an alloc method, and cannot accept
                # values by rvalue reference. Downgrade to an lvalue reference.
                if direction == "recv" and t.rvalref:
                    t.rvalref = False
                    t.ref = True
                return Decl(t, d.name)
            if sems == "move":
                assert direction == "recv"
                # For legacy reasons, use an rvalue reference when generating
                # parameters for recv methods which accept arrays.
                if d.ipdltype.isIPDL() and d.ipdltype.isArray():
                    t = d.bareType(side)
                    t.rvalref = True
                    return Decl(t, d.name)
                return Decl(d.inType(side, direction), d.name)
            if sems == "out":
                return Decl(d.outType(side), d.name)
            assert 0

        def resolverDecl(returns):
            return Decl(Type(self.resolverName(), rvalref=True), "aResolve")

        def callbackResolveDecl(returns):
            resolvetype = (
                _tuple([d.bareType(side) for d in returns])
                if len(returns) > 1
                else returns[0].bareType(side)
            )
            return Decl(
                Type("mozilla::ipc::ResolveCallback", T=resolvetype, rvalref=True),
                "aResolve",
            )

        def callbackRejectDecl(returns):
            return Decl(Type("mozilla::ipc::RejectCallback", rvalref=True), "aReject")

        decls = []
        if paramsems is not None:
            decls.extend([declFor(d, paramsems) for d in self.params])

        if returnsems == "promise" and self.returns:
            # Promise-based senders take no return parameters.
            pass
        elif returnsems == "callback" and self.returns:
            decls.append(callbackResolveDecl(self.returns))
            decls.append(callbackRejectDecl(self.returns))
        elif returnsems == "resolver" and self.returns:
            decls.append(resolverDecl(self.returns))
        elif returnsems is not None:
            decls.extend([declFor(r, returnsems) for r in self.returns])

        if not implicit and self.decl.type.hasImplicitActorParam():
            decls = decls[1:]
        return decls

    def makeCxxArgs(
        self, paramsems="in", retsems="out", retcallsems="out", implicit=True
    ):
        """Return C++ argument expressions matching a makeCxxParams() signature.

        |paramsems| is 'in' or 'move'. |retsems| and |retcallsems| ('in' or
        'out') control whether each return value is passed as-is, by
        address, or dereferenced. retsems='resolver' appends
        ``std::move(resolver)``.
        """
        assert not retcallsems or retsems  # retcallsems => returnsems
        args = []
        if paramsems == "move":
            # We don't std::move() RefPtr<T> types because current Recv*()
            # implementors take these parameters as T*, and
            # std::move(RefPtr<T>) doesn't coerce to T*.
            # We also don't move NotNull, as it has no move constructor.
            for p in self.params:
                passAsIs = p.ipdltype.isRefcounted() or (
                    p.ipdltype.isIPDL() and p.ipdltype.isNotNull()
                )
                args.append(p.var() if passAsIs else ExprMove(p.var()))
        elif paramsems == "in":
            args.extend([p.var() for p in self.params])
        else:
            assert False

        for ret in self.returns:
            if retsems == "in":
                if retcallsems == "in":
                    args.append(ret.var())
                elif retcallsems == "out":
                    args.append(ExprAddrOf(ret.var()))
                else:
                    assert 0
            elif retsems == "out":
                if retcallsems == "in":
                    args.append(ExprDeref(ret.var()))
                elif retcallsems == "out":
                    args.append(ret.var())
                else:
                    assert 0
            # retsems == "resolver": nothing per return value.

        if retsems == "resolver":
            args.append(ExprMove(ExprVar("resolver")))
        if not implicit:
            assert self.decl.type.hasImplicitActorParam()
            args = args[1:]
        return args

    @staticmethod
    def upgrade(messageDecl):
        """Re-class |messageDecl| in place.

        When the message type has an implicit actor param, an "actor" decl
        is first prepended to its params.
        """
        assert isinstance(messageDecl, ipdl.ast.MessageDecl)
        msgtype = messageDecl.decl.type
        if msgtype.hasImplicitActorParam():
            actorType = ipdl.type.ActorType(msgtype.constructedType())
            messageDecl.params.insert(0, _HybridDecl(actorType, "actor"))
        messageDecl.__class__ = MessageDecl
# --------------------------------------------------
def _usesShmem(p):
    """True iff some message of protocol |p| has a Shmem-bearing in- or out-param.

    Each message's in-params are checked before its out-params.
    """
    return any(
        ipdl.type.hasshmem(param.type)
        for md in p.messageDecls
        for param in itertools.chain(md.inParams, md.outParams)
    )
def _subtreeUsesShmem(p):
    """True iff |p| or any protocol it (transitively) manages uses Shmem.

    A protocol that manages itself is not revisited through that edge.
    """
    if _usesShmem(p):
        return True
    ptype = p.decl.type
    return any(
        _subtreeUsesShmem(managed._ast)
        for managed in ptype.manages
        if managed is not ptype
    )
class Protocol(ipdl.ast.Protocol):
    """ipdl.ast.Protocol extended with helpers that name the C++ members
    and expressions generated for its actors."""

    def _ipdlmgrtype(self):
        managers = self.decl.type.managers
        assert 1 == len(managers)
        # The sole manager (None if the set is empty and asserts are off).
        return next(iter(managers), None)

    def managerActorType(self, side, ptr=False):
        managerName = self._ipdlmgrtype().name()
        return Type(_actorName(managerName, side), ptr=ptr)

    def unregisterMethod(self, actorThis=None):
        if actorThis is None:
            return ExprVar("Unregister")
        return ExprSelect(actorThis, "->", "Unregister")

    def removeManageeMethod(self):
        return ExprVar("RemoveManagee")

    def deallocManageeMethod(self):
        return ExprVar("DeallocManagee")

    def getChannelMethod(self):
        return ExprVar("GetIPCChannel")

    def callGetChannel(self, actorThis=None):
        getter = self.getChannelMethod()
        if actorThis is not None:
            getter = ExprSelect(actorThis, "->", getter.name)
        return ExprCall(getter)

    def processingErrorVar(self):
        assert self.decl.type.isToplevel()
        return ExprVar("ProcessingError")

    def shouldContinueFromTimeoutVar(self):
        assert self.decl.type.isToplevel()
        return ExprVar("ShouldContinueFromReplyTimeout")

    def routingId(self, actorThis=None):
        """MSG_ROUTING_CONTROL for toplevel protocols, else a call to Id()."""
        if self.decl.type.isToplevel():
            return ExprVar("MSG_ROUTING_CONTROL")
        if actorThis is None:
            return ExprCall(ExprVar("Id"))
        return ExprCall(ExprSelect(actorThis, "->", "Id"))

    def managerVar(self, thisexpr=None):
        assert thisexpr is not None or not self.decl.type.isToplevel()
        if thisexpr is None:
            return ExprCall(ExprVar("Manager"), args=[])
        return ExprCall(ExprSelect(thisexpr, "->", "Manager"), args=[])

    def _managedActorName(self, actortype, side):
        # Every managed-actor helper first checks that we manage |actortype|.
        assert self.decl.type.isManagerOf(actortype)
        return _actorName(actortype.name(), side)

    def managedCxxType(self, actortype, side):
        return Type(self._managedActorName(actortype, side), ptr=True)

    def managedMethod(self, actortype, side):
        return ExprVar("Managed" + self._managedActorName(actortype, side))

    def managedVar(self, actortype, side):
        return ExprVar("mManaged" + self._managedActorName(actortype, side))

    def managedVarType(self, actortype, side, const=False, ref=False):
        actorClass = Type(self._managedActorName(actortype, side))
        return _cxxManagedContainerType(actorClass, const=const, ref=ref)

    def subtreeUsesShmem(self):
        return _subtreeUsesShmem(self)

    @staticmethod
    def upgrade(protocol):
        """Re-class |protocol| (an ipdl.ast.Protocol) in place."""
        assert isinstance(protocol, ipdl.ast.Protocol)
        protocol.__class__ = Protocol
class TranslationUnit(ipdl.ast.TranslationUnit):
    """Lowering-phase marker subclass of ipdl.ast.TranslationUnit.

    It adds no members; |upgrade| swaps the class of an existing node in
    place so later passes can recognise already-decorated units.
    """

    @staticmethod
    def upgrade(tu):
        assert isinstance(tu, ipdl.ast.TranslationUnit)
        tu.__class__ = TranslationUnit
# -----------------------------------------------------------------------------
# Byte sizes of the primitive C++ types that struct serialization may copy as
# raw bytes, keyed by the fully-qualified C++ type name.
pod_types = {
    "::int8_t": 1,
    "::uint8_t": 1,
    "::int16_t": 2,
    "::uint16_t": 2,
    "::int32_t": 4,
    "::uint32_t": 4,
    "::int64_t": 8,
    "::uint64_t": 8,
    "float": 4,
    "double": 8,
}
max_pod_size = max(pod_types.values())
# Every type we don't recognise is treated as "bigger" than any POD type, so a
# descending size sort places non-POD values first.
pod_size_sentinel = 2 * max_pod_size


def pod_size(ipdltype):
    """Return the byte size of |ipdltype| if it is a known POD C++ type;
    otherwise return pod_size_sentinel."""
    if ipdltype.isCxx():
        return pod_types.get(ipdltype.fullname(), pod_size_sentinel)
    return pod_size_sentinel
class _DecorateWithCxxStuff(ipdl.ast.Visitor):
    """Phase 1 of lowering: decorate the IPDL AST with information
    relevant to C++ code generation.

    This pass results in an AST that is a poor man's "IR"; in reality, a
    "hybrid" AST mainly consisting of IPDL nodes with new C++ info along
    with some new IPDL/C++ nodes that are tuned for C++ codegen.
    """

    def __init__(self):
        # Translation units already decorated; each one is visited only once.
        self.visitedTus = set()
        # Name of the protocol being visited. It is stamped onto each
        # MessageDecl as its `namespace`.
        self.protocolName = None

    def visitTranslationUnit(self, tu):
        if tu in self.visitedTus:
            return
        self.visitedTus.add(tu)
        ipdl.ast.Visitor.visitTranslationUnit(self, tu)
        if not isinstance(tu, TranslationUnit):
            TranslationUnit.upgrade(tu)

    def visitInclude(self, inc):
        # Only header (.ipdlh) includes are descended into.
        if inc.tu.filetype == "header":
            inc.tu.accept(self)

    def visitProtocol(self, pro):
        self.protocolName = pro.name
        Protocol.upgrade(pro)
        return ipdl.ast.Visitor.visitProtocol(self, pro)

    def visitStructDecl(self, sd):
        if isinstance(sd, StructDecl):
            # Already decorated.
            return

        fields = [_StructField(f.decl.type, f.name, sd) for f in sd.fields]

        # Storage order is largest pod_size first, so same-sized POD members
        # sit next to each other and the struct packs well. Non-POD fields
        # report pod_size_sentinel (the largest value), so they come first.
        # sorted() is stable, so equal sizes keep their IPDL order.
        def storageSize(index):
            return pod_size(fields[index].ipdltype)

        packedOrder = sorted(range(len(fields)), key=storageSize, reverse=True)

        sd.fields = fields
        sd.packed_field_order = packedOrder
        StructDecl.upgrade(sd)

    def visitUnionDecl(self, ud):
        members = [_UnionMember(ctype, ud) for ctype in ud.decl.type.components]
        ud.components = members
        UnionDecl.upgrade(ud)

    def visitDecl(self, decl):
        return _HybridDecl(decl.type, decl.progname, decl.attributes)

    def visitMessageDecl(self, md):
        md.namespace = self.protocolName
        md.params = [inParam.accept(self) for inParam in md.inParams]
        md.returns = [outParam.accept(self) for outParam in md.outParams]
        MessageDecl.upgrade(md)
# -----------------------------------------------------------------------------
def msgenums(protocol, pretty=False):
    """Build the C++ `MessageType` enum for |protocol|.

    The enum begins with `<Protocol>Start`, whose value is the protocol's
    message-start name shifted left by 16. It then has one id per message,
    plus one more for the reply when the message has one, and ends with
    `<Protocol>End`. With |pretty| set, the pretty message/reply names are
    used instead of msgId()/replyId().
    """
    msgenum = TypeEnum("MessageType")
    startValue = _messageStartName(protocol.decl.type) + " << 16"
    msgenum.addId(protocol.name + "Start", startValue)

    msgAttr, replyAttr = (
        ("prettyMsgName", "prettyReplyName") if pretty else ("msgId", "replyId")
    )
    for md in protocol.messageDecls:
        msgenum.addId(getattr(md, msgAttr)())
        if md.hasReply():
            msgenum.addId(getattr(md, replyAttr)())

    msgenum.addId(protocol.name + "End")
    return msgenum
class _GenerateProtocolCode(ipdl.ast.Visitor):
    """Creates code common to both the parent and child actors."""

    def __init__(self):
        self.protocol = None  # protocol we're generating a class for
        self.hdrfile = None  # what will become Protocol.h
        self.cppfile = None  # what will become Protocol.cpp
        self.cppIncludeHeaders = []
        self.structUnionDefns = []
        self.funcDefns = []
        # "Namespace::MessageName" -> segment capacity. Replaced by lower();
        # initialised here so the attribute always exists.
        self.segmentcapacitydict = {}

    def lower(self, tu, cxxHeaderFile, cxxFile, segmentcapacitydict):
        """Emit the shared protocol header/source for |tu| into the given
        File objects. |segmentcapacitydict| maps "Namespace::MessageName" to
        the segment capacity used by that message's constructor."""
        self.protocol = tu.protocol
        self.hdrfile = cxxHeaderFile
        self.cppfile = cxxFile
        self.segmentcapacitydict = segmentcapacitydict
        tu.accept(self)

    def visitTranslationUnit(self, tu):
        hf = self.hdrfile

        hf.addthing(_DISCLAIMER)
        hf.addthings(_includeGuardStart(hf))
        hf.addthing(Whitespace.NL)

        for inc in builtinHeaderIncludes:
            self.visitBuiltinCxxInclude(inc)

        # Compute the set of includes we need for declared structure/union
        # classes for this protocol.
        typesToIncludes = {}
        for using in tu.using:
            typestr = str(using.type)
            if typestr not in typesToIncludes:
                typesToIncludes[typestr] = using.header
            else:
                # The same type must not be imported from two headers.
                assert typesToIncludes[typestr] == using.header

        aggregateTypeIncludes = set()
        for su in tu.structsAndUnions:
            typedeps = _ComputeTypeDeps(su.decl.type, typesToIncludes)
            if isinstance(su, ipdl.ast.StructDecl):
                aggregateTypeIncludes.add("mozilla/ipc/IPDLStructMember.h")
                for f in su.fields:
                    f.ipdltype.accept(typedeps)
            elif isinstance(su, ipdl.ast.UnionDecl):
                for c in su.components:
                    c.ipdltype.accept(typedeps)

            aggregateTypeIncludes.update(typedeps.includeHeaders)

        if aggregateTypeIncludes:
            hf.addthing(Whitespace.NL)
            hf.addthings([Whitespace("// Headers for typedefs"), Whitespace.NL])

            # Sorted so the generated header is deterministic.
            for headername in sorted(aggregateTypeIncludes):
                hf.addthing(CppDirective("include", '"' + headername + '"'))

        # Manually run Visitor.visitTranslationUnit. For dependency resolution
        # we need to handle structs and unions separately.
        for cxxInc in tu.cxxIncludes:
            cxxInc.accept(self)
        for inc in tu.includes:
            inc.accept(self)
        self.generateStructsAndUnions(tu)
        for using in tu.builtinUsing:
            using.accept(self)
        for using in tu.using:
            using.accept(self)
        if tu.protocol:
            tu.protocol.accept(self)

        if tu.filetype == "header":
            self.cppIncludeHeaders.append(_ipdlhHeaderName(tu) + ".h")

        hf.addthing(Whitespace.NL)
        hf.addthings(_includeGuardEnd(hf))

        cf = self.cppfile
        cf.addthings(
            (
                [_DISCLAIMER, Whitespace.NL]
                + [
                    CppDirective("include", '"' + h + '"')
                    for h in self.cppIncludeHeaders
                ]
                + [Whitespace.NL]
                + [
                    CppDirective("include", '"%s"' % filename)
                    for filename in ipdl.builtin.CppIncludes
                ]
                + [Whitespace.NL]
            )
        )

        if self.protocol:
            # construct the namespace into which we'll stick all our defns
            ns = Namespace(self.protocol.name)
            cf.addthing(_putInNamespaces(ns, self.protocol.namespaces))
            ns.addstmts(([Whitespace.NL] + self.funcDefns + [Whitespace.NL]))

        cf.addthings(self.structUnionDefns)

    def visitBuiltinCxxInclude(self, inc):
        self.hdrfile.addthing(CppDirective("include", '"' + inc.file + '"'))

    def visitCxxInclude(self, inc):
        self.cppIncludeHeaders.append(inc.file)

    def visitInclude(self, inc):
        if inc.tu.filetype == "header":
            self.hdrfile.addthing(
                CppDirective("include", '"' + _ipdlhHeaderName(inc.tu) + '.h"')
            )
            # Inherit cpp includes defined by imported header files, as they may
            # be required to serialize an imported `using` type.
            for cxxinc in inc.tu.cxxIncludes:
                cxxinc.accept(self)
        else:
            self.cppIncludeHeaders += [
                _protocolHeaderName(inc.tu.protocol, "parent") + ".h",
                _protocolHeaderName(inc.tu.protocol, "child") + ".h",
            ]

    def generateStructsAndUnions(self, tu):
        """Generate the definitions for all structs and unions. This will
        re-order the declarations if needed in the C++ code such that
        dependencies have already been defined.
        """
        decls = OrderedDict()
        for su in tu.structsAndUnions:
            if isinstance(su, StructDecl):
                which = "struct"
                forwarddecls, fulldecltypes, cls = _generateCxxStruct(su)
                traitsdecl, traitsdefns = _ParamTraits.structPickling(su.decl.type)
            else:
                assert isinstance(su, UnionDecl)
                which = "union"
                forwarddecls, fulldecltypes, cls = _generateCxxUnion(su)
                traitsdecl, traitsdefns = _ParamTraits.unionPickling(su.decl.type)

            clsdecl, methoddefns = _splitClassDeclDefn(cls)

            # Store the declarations in the decls map so we can emit in
            # dependency order.
            decls[su.decl.type] = (
                fulldecltypes,
                [Whitespace.NL]
                + forwarddecls
                + [
                    Whitespace(
                        """
//-----------------------------------------------------------------------------
// Declaration of the IPDL type |%s %s|
//
"""
                        % (which, su.name)
                    ),
                    _putInNamespaces(clsdecl, su.namespaces),
                ]
                + [Whitespace.NL, traitsdecl],
            )

            # The banner text must stay on one line: every line of it is
            # emitted verbatim into the .cpp, so it all has to be a C++
            # `//` comment.
            self.structUnionDefns.extend(
                [
                    Whitespace(
                        """
//-----------------------------------------------------------------------------
// Method definitions for the IPDL type |%s %s|
//
"""
                        % (which, su.name)
                    ),
                    _putInNamespaces(methoddefns, su.namespaces),
                    Whitespace.NL,
                    traitsdefns,
                ]
            )

        # Generate the declarations structs in dependency order.
        def gen_struct(deps, defn):
            # Emit every not-yet-emitted dependency first (depth-first),
            # removing it from |decls| so it is emitted exactly once.
            for dep in deps:
                if dep in decls:
                    d, t = decls[dep]
                    del decls[dep]
                    gen_struct(d, t)
            self.hdrfile.addthings(defn)

        while len(decls) > 0:
            # FIFO: take the earliest-declared remaining type.
            _, (d, t) = decls.popitem(False)
            gen_struct(d, t)

    def visitProtocol(self, p):
        self.cppIncludeHeaders.append(_protocolHeaderName(self.protocol, "") + ".h")
        self.cppIncludeHeaders.append(
            _protocolHeaderName(self.protocol, "Parent") + ".h"
        )
        self.cppIncludeHeaders.append(
            _protocolHeaderName(self.protocol, "Child") + ".h"
        )

        # Forward declare our own actors.
        self.hdrfile.addthings(
            [
                Whitespace.NL,
                _makeForwardDeclForActor(p.decl.type, "Parent"),
                _makeForwardDeclForActor(p.decl.type, "Child"),
            ]
        )

        # As above, the banner must be a single `//` line so it stays a
        # comment in the generated header.
        self.hdrfile.addthing(
            Whitespace(
                """
//-----------------------------------------------------------------------------
// Code common to %sChild and %sParent
//
"""
                % (p.name, p.name)
            )
        )

        # construct the namespace into which we'll stick all our decls
        ns = Namespace(self.protocol.name)
        self.hdrfile.addthing(_putInNamespaces(ns, p.namespaces))
        ns.addstmt(Whitespace.NL)

        for func in self.genEndpointFuncs():
            edecl, edefn = _splitFuncDeclDefn(func)
            ns.addstmts([edecl, Whitespace.NL])
            self.funcDefns.append(edefn)

        # spit out message type enum and classes
        msgenum = msgenums(self.protocol)
        ns.addstmts([StmtDecl(Decl(msgenum, "")), Whitespace.NL])

        for md in p.messageDecls:
            decls = []

            # Look up the segment capacity used for serializing this
            # message. If the capacity is not specified, use '0' for
            # the default capacity (defined in ipc_message.cc)
            name = "%s::%s" % (md.namespace, md.decl.progname)
            segmentcapacity = self.segmentcapacitydict.get(name, 0)

            mfDecl, mfDefn = _splitFuncDeclDefn(
                _generateMessageConstructor(md, segmentcapacity, p, forReply=False)
            )
            decls.append(mfDecl)
            self.funcDefns.append(mfDefn)

            if md.hasReply():
                rfDecl, rfDefn = _splitFuncDeclDefn(
                    _generateMessageConstructor(md, 0, p, forReply=True)
                )
                decls.append(rfDecl)
                self.funcDefns.append(rfDefn)

            decls.append(Whitespace.NL)
            ns.addstmts(decls)

        ns.addstmts([Whitespace.NL, Whitespace.NL])

    # Generate code for PFoo::CreateEndpoints.
    def genEndpointFuncs(self):
        p = self.protocol.decl.type
        tparent = _cxxBareType(ActorType(p), "Parent", fq=True)
        tchild = _cxxBareType(ActorType(p), "Child", fq=True)

        def mkOverload(includepids):
            # One CreateEndpoints overload; with |includepids| it also takes
            # the parent/child EndpointProcInfo arguments first.
            params = []
            if includepids:
                params = [
                    Decl(Type("mozilla::ipc::EndpointProcInfo"), "aParentDestInfo"),
                    Decl(Type("mozilla::ipc::EndpointProcInfo"), "aChildDestInfo"),
                ]
            params += [
                Decl(
                    Type("mozilla::ipc::Endpoint<" + tparent.name + ">", ptr=True),
                    "aParent",
                ),
                Decl(
                    Type("mozilla::ipc::Endpoint<" + tchild.name + ">", ptr=True),
                    "aChild",
                ),
            ]
            openfunc = MethodDefn(
                MethodDecl("CreateEndpoints", params=params, ret=Type.NSRESULT)
            )
            openfunc.addcode(
                """
return mozilla::ipc::CreateEndpoints(
mozilla::ipc::PrivateIPDLInterface(),
$,{args});
""",
                args=[ExprVar(d.name) for d in params],
            )
            return openfunc

        funcs = [mkOverload(True)]
        if not p.hasOtherPid():
            funcs.append(mkOverload(False))
        return funcs
# --------------------------------------------------
# IPC::Message priority enumerator names, index-aligned with
# ipdl.ast.priorityList: each IPDL priority name is upper-cased and given a
# "_PRIORITY" suffix.
cppPriorityList = [prio.upper() + "_PRIORITY" for prio in ipdl.ast.priorityList]
def _generateMessageConstructor(md, segmentSize, protocol, forReply=False):
    """Return a FunctionDefn for the C++ factory that creates the
    IPC::Message for |md|, or for its reply when |forReply| is true.

    The generated function takes an `int32_t routingId` and returns
    `IPC::Message::IPDLMessage(routingId, <id>, <segmentSize>,
    IPC::Message::HeaderFlags(...))`, with the flags derived from the
    message's nesting, priority, compression, lazy-send, constructor, sync
    and reply properties.

    |segmentSize| is converted with int(). Any falsy value (0, None, "")
    becomes 0, which asks for the default segment capacity.
    |protocol| is accepted for interface compatibility and is not used here.
    """
    if forReply:
        clsname = md.replyCtorFunc()
        msgid = md.replyId()
        replyEnum = "REPLY"
        prioEnum = cppPriorityList[md.decl.type.replyPrio]
    else:
        clsname = md.msgCtorFunc()
        msgid = md.msgId()
        replyEnum = "NOT_REPLY"
        prioEnum = cppPriorityList[md.decl.type.prio]

    nested = md.decl.type.nested
    compress = md.decl.type.compress
    lazySend = md.decl.type.lazySend

    routingId = ExprVar("routingId")

    func = FunctionDefn(
        FunctionDecl(
            clsname,
            params=[Decl(Type("int32_t"), routingId.name)],
            ret=Type("mozilla::UniquePtr<IPC::Message>"),
        )
    )

    if not compress:
        compression = "COMPRESSION_NONE"
    elif compress.value == "all":
        compression = "COMPRESSION_ALL"
    else:
        assert compress.value is None
        compression = "COMPRESSION_ENABLED"

    lazySendEnum = "LAZY_SEND" if lazySend else "EAGER_SEND"

    if nested == ipdl.ast.NOT_NESTED:
        nestedEnum = "NOT_NESTED"
    elif nested == ipdl.ast.INSIDE_SYNC_NESTED:
        nestedEnum = "NESTED_INSIDE_SYNC"
    else:
        assert nested == ipdl.ast.INSIDE_CPOW_NESTED
        nestedEnum = "NESTED_INSIDE_CPOW"

    syncEnum = "SYNC" if md.decl.type.isSync() else "ASYNC"
    ctorEnum = "CONSTRUCTOR" if md.decl.type.isCtor() else "NOT_CONSTRUCTOR"

    def messageEnum(valname):
        return ExprVar("IPC::Message::" + valname)

    # The argument order must match IPC::Message::HeaderFlags' constructor.
    flags = ExprCall(
        ExprVar("IPC::Message::HeaderFlags"),
        args=[
            messageEnum(nestedEnum),
            messageEnum(prioEnum),
            messageEnum(compression),
            messageEnum(lazySendEnum),
            messageEnum(ctorEnum),
            messageEnum(syncEnum),
            messageEnum(replyEnum),
        ],
    )

    # Test for "unset" BEFORE converting. The old code converted first, which
    # made its falsy check dead and turned None/"" into a crash instead of
    # the default capacity (0).
    segmentSize = int(segmentSize) if segmentSize else 0

    func.addstmt(
        StmtReturn(
            ExprCall(
                ExprVar("IPC::Message::IPDLMessage"),
                args=[
                    routingId,
                    ExprVar(msgid),
                    ExprLiteral.Int(segmentSize),
                    flags,
                ],
            )
        )
    )

    return func
# --------------------------------------------------
class _ParamTraits:
    """Generates IPC::ParamTraits<> specializations (static Write/Read) for
    IPDL actors, structs and unions, together with the sentinel-checked
    write/read statement snippets they are assembled from.

    The generators return (declaration, definitions) pairs wrapped in the
    `IPC` namespace; see generateDecl.
    """

    # Names of the parameters of the generated Write/Read methods.
    var = ExprVar("aVar")
    writervar = ExprVar("aWriter")
    readervar = ExprVar("aReader")

    @classmethod
    def ifsideis(cls, rdrwtr, side, then, els=None):
        # Build `if (<side> == rdrwtr->GetActor()->GetSide()) { then }`,
        # plus an else branch when |els| is given.
        ifstmt = StmtIf(
            ExprBinary(
                _cxxSide(side),
                "==",
                ExprCode("${rdrwtr}->GetActor()->GetSide()", rdrwtr=rdrwtr),
            )
        )
        ifstmt.addifstmt(then)
        if els is not None:
            ifstmt.addelsestmt(els)
        return ifstmt

    @classmethod
    def fatalError(cls, rdrwtr, reason):
        # `rdrwtr->FatalError("<reason>");`
        return StmtCode(
            "${rdrwtr}->FatalError(${reason});",
            rdrwtr=rdrwtr,
            reason=ExprLiteral.String(reason),
        )

    @classmethod
    def writeSentinel(cls, writervar, sentinelKey):
        # A comment naming the key, then `writervar->WriteSentinel(hash(key))`.
        return [
            Whitespace("// Sentinel = " + repr(sentinelKey) + "\n", indent=True),
            StmtExpr(
                ExprCall(
                    ExprSelect(writervar, "->", "WriteSentinel"),
                    args=[ExprLiteral.Int(hashfunc(sentinelKey))],
                )
            ),
        ]

    @classmethod
    def readSentinel(cls, readervar, sentinelKey, sentinelFail):
        # Read the sentinel; run the |sentinelFail| statements if it does not
        # match the hash of |sentinelKey|.
        read = ExprCall(
            ExprSelect(readervar, "->", "ReadSentinel"),
            args=[ExprLiteral.Int(hashfunc(sentinelKey))],
        )
        ifsentinel = StmtIf(ExprNot(read))
        ifsentinel.addifstmts(sentinelFail)

        return [
            Whitespace("// Sentinel = " + repr(sentinelKey) + "\n", indent=True),
            ifsentinel,
        ]

    @classmethod
    def write(cls, var, writervar, ipdltype=None):
        # `IPC::WriteParam(writervar, var)`. |var| is wrapped in a move when
        # |ipdltype| is given and _cxxTypeNeedsMoveForSend says so.
        if ipdltype and _cxxTypeNeedsMoveForSend(ipdltype):
            var = ExprMove(var)
        return ExprCall(ExprVar("IPC::WriteParam"), args=[writervar, var])

    @classmethod
    def checkedWrite(cls, ipdltype, var, writervar, sentinelKey):
        # A Block that writes |var| and then the sentinel for |sentinelKey|.
        assert sentinelKey
        block = Block()

        block.addstmts(
            [
                StmtExpr(cls.write(var, writervar, ipdltype)),
            ]
        )
        block.addstmts(cls.writeSentinel(writervar, sentinelKey))
        return block

    @classmethod
    def bulkSentinelKey(cls, fields):
        # Sentinel key for a bulk-serialized run: "a | b | c".
        return " | ".join(f.basename for f in fields)

    @classmethod
    def checkedBulkWrite(cls, var, size, fields):
        # Write a contiguous run of same-sized POD |fields| with one
        # WriteBytes call: size * len(fields) bytes starting at the address
        # of the first field. |fields| must be non-empty.
        block = Block()
        first = fields[0]

        block.addstmts(
            [
                StmtExpr(
                    ExprCall(
                        ExprSelect(cls.writervar, "->", "WriteBytes"),
                        args=[
                            ExprAddrOf(
                                ExprCall(first.getMethod(thisexpr=var, sel="."))
                            ),
                            ExprLiteral.Int(size * len(fields)),
                        ],
                    )
                )
            ]
        )
        block.addstmts(cls.writeSentinel(cls.writervar, cls.bulkSentinelKey(fields)))

        return block

    @classmethod
    def checkedBulkRead(cls, var, size, fields):
        # Read counterpart of checkedBulkWrite: ReadBytesInto the first
        # field's address (through `var->`), failing with FatalError and a
        # read-result error on short reads or sentinel mismatch.
        block = Block()
        first = fields[0]

        readbytes = ExprCall(
            ExprSelect(cls.readervar, "->", "ReadBytesInto"),
            args=[
                ExprAddrOf(ExprCall(first.getMethod(thisexpr=var, sel="->"))),
                ExprLiteral.Int(size * len(fields)),
            ],
        )
        ifbad = StmtIf(ExprNot(readbytes))
        errmsg = "Error bulk reading fields from %s" % first.ipdltype.name()
        ifbad.addifstmts(
            [cls.fatalError(cls.readervar, errmsg), StmtReturn(readResultError())]
        )
        block.addstmt(ifbad)
        block.addstmts(
            cls.readSentinel(
                cls.readervar,
                cls.bulkSentinelKey(fields),
                errfnSentinel(readResultError())(errmsg),
            )
        )

        return block

    @classmethod
    def checkedRead(
        cls,
        ipdltype,
        cxxtype,
        var,
        readervar,
        errfn,
        paramtype,
        sentinelKey,
        errfnSentinel,
    ):
        # Emit `auto maybe__<var> = IPC::ReadParam<cxxtype>(reader)`, run
        # errfn(*paramtype) if it is empty, bind `auto& <var>` to the value,
        # then check the sentinel. |paramtype| is either a list of error-fn
        # arguments or a string that is turned into
        # ["Error deserializing " + paramtype].
        # NOTE(review): |ipdltype| is not used by this method.
        assert isinstance(var, ExprVar)

        if not isinstance(paramtype, list):
            paramtype = ["Error deserializing " + paramtype]

        block = Block()

        # Read the data
        block.addcode(
            """
auto ${maybevar} = IPC::ReadParam<${ty}>(${reader});
if (!${maybevar}) {
$*{errfn}
}
auto& ${var} = *${maybevar};
""",
            maybevar=ExprVar("maybe__" + var.name),
            ty=cxxtype,
            reader=readervar,
            errfn=errfn(*paramtype),
            var=var,
        )

        block.addstmts(
            cls.readSentinel(readervar, sentinelKey, errfnSentinel(*paramtype))
        )

        return block

    # Helper wrapper for checkedRead for use within _ParamTraits
    @classmethod
    def _checkedRead(cls, ipdltype, cxxtype, var, sentinelKey, what):
        # On failure: FatalError(msg) on the reader, then return a
        # read-result error.
        def errfn(msg):
            return [cls.fatalError(cls.readervar, msg), StmtReturn(readResultError())]

        return cls.checkedRead(
            ipdltype,
            cxxtype,
            var,
            cls.readervar,
            errfn=errfn,
            paramtype=what,
            sentinelKey=sentinelKey,
            errfnSentinel=errfnSentinel(readResultError()),
        )

    @classmethod
    def generateDecl(cls, fortype, write, read, needsmove=False):
        # Build `struct IPC::ParamTraits<fortype>` with static Write (taking
        # the value by const& or, with |needsmove|, by &&) and static Read
        # (returning IPC::ReadResult<paramType>). Returns the class
        # declaration and the out-of-line method definitions, each wrapped
        # in namespace IPC.
        # ParamTraits impls are selected ignoring constness, and references.
        pt = Class(
            "ParamTraits",
            specializes=Type(
                fortype.name, T=fortype.T, inner=fortype.inner, ptr=fortype.ptr
            ),
            struct=True,
        )

        # typedef T paramType;
        pt.addstmt(Typedef(fortype, "paramType"))

        # static void Write(Message*, const T&);
        if needsmove:
            intype = Type("paramType", rvalref=True)
        else:
            intype = Type("paramType", ref=True, const=True)
        writemthd = MethodDefn(
            MethodDecl(
                "Write",
                params=[
                    Decl(Type("IPC::MessageWriter", ptr=True), cls.writervar.name),
                    Decl(intype, cls.var.name),
                ],
                methodspec=MethodSpec.STATIC,
            )
        )
        writemthd.addstmts(write)
        pt.addstmt(writemthd)

        # static ReadResult<T> Read(MessageReader*);
        readmthd = MethodDefn(
            MethodDecl(
                "Read",
                params=[
                    Decl(Type("IPC::MessageReader", ptr=True), cls.readervar.name),
                ],
                ret=Type("IPC::ReadResult<paramType>"),
                methodspec=MethodSpec.STATIC,
            )
        )
        readmthd.addstmts(read)
        pt.addstmt(readmthd)

        # Split the class into declaration and definition
        clsdecl, methoddefns = _splitClassDeclDefn(pt)

        namespaces = [Namespace("IPC")]
        clsns = _putInNamespaces(clsdecl, namespaces)
        defns = _putInNamespaces(methoddefns, namespaces)
        return clsns, defns

    @classmethod
    def actorPickling(cls, actortype, side):
        """Generates pickling for IPDL actors. This is a |nullable| deserializer.
        Write and read callers will perform nullability validation.
        """

        cxxtype = _cxxBareType(actortype, side, fq=True)

        # Write: serialize the actor's routing id (0 for null), asserting the
        # actor is live, open and on the same channel as the writer's actor.
        write = StmtCode(
            """
MOZ_RELEASE_ASSERT(
${writervar}->GetActor(),
"Cannot serialize managed actors without an actor");
int32_t id;
if (!${var}) {
id = 0; // kNullActorId
}
else {
id = ${var}->Id();
if (id == 1) { // kFreedActorId
${var}->FatalError(
"Actor has been |delete|d");
}
MOZ_RELEASE_ASSERT(
${writervar}->GetActor()->GetIPCChannel() == ${var}->GetIPCChannel(),
"Actor must be from the same channel as the"
" actor it's being sent over");
MOZ_RELEASE_ASSERT(
${var}->CanSend(),
"Actor must still be open when sending");
}
${write};
""",
            var=cls.var,
            writervar=cls.writervar,
            write=cls.write(ExprVar("id"), cls.writervar),
        )

        # bool Read(..) impl
        # Read: resolve the actor through the reader's actor (ReadActor with
        # nullable = true) and static_cast it to the concrete actor type.
        read = StmtCode(
            """
MOZ_RELEASE_ASSERT(
${readervar}->GetActor(),
"Cannot deserialize managed actors without an actor");
mozilla::Maybe<mozilla::ipc::IProtocol*> actor = ${readervar}->GetActor()
->ReadActor(${readervar},
true, ${actortype}, ${protocolid});
if (actor.isSome()) {
return static_cast<${cxxtype}>(actor.ref());
}
return {};
""",
            readervar=cls.readervar,
            actortype=ExprLiteral.String(actortype.name()),
            protocolid=_protocolId(actortype),
            cxxtype=cxxtype,
        )

        return cls.generateDecl(cxxtype, [write], [read])

    @classmethod
    def structPickling(cls, structtype):
        # ParamTraits for an IPDL struct: non-POD fields are written/read one
        # at a time in IPDL order and fed to the constructor; POD fields are
        # then bulk-copied in member (packed) order.
        sd = structtype._ast
        # NOTE: Not using _cxxBareType here as we don't have a side
        cxxtype = Type(structtype.fullname())

        write = []
        read = []

        # First serialize/deserialize all non-pod data in IPDL order. These need
        # to be read/written first because they'll be used to invoke the IPDL
        # struct's constructor.
        ctorargs = []
        for f in sd.fields_ipdl_order():
            if pod_size(f.ipdltype) == pod_size_sentinel:
                write.append(
                    cls.checkedWrite(
                        f.ipdltype,
                        ExprCall(f.getMethod(thisexpr=cls.var, sel=".")),
                        cls.writervar,
                        sentinelKey=f.basename,
                    )
                )
                read.append(
                    cls._checkedRead(
                        f.ipdltype,
                        f.bareType(fq=True),
                        f.argVar(),
                        f.basename,
                        "'"
                        + f.getMethod().name
                        + "' "
                        + "("
                        + f.ipdltype.name()
                        + ") member of "
                        + "'"
                        + structtype.name()
                        + "'",
                    )
                )
                if _cxxTypeCanMove(f.ipdltype):
                    ctorargs.append(ExprMove(f.argVar()))
                else:
                    ctorargs.append(f.argVar())
            else:
                # We're going to bulk-read in this value later, so we'll just
                # zero-initialize it for now.
                ctorargs.append(ExprCode("${type}{0}", type=f.bareType(fq=True)))

        resultvar = ExprVar("result__")
        read.append(
            StmtDecl(
                Decl(_cxxReadResultType(Type("paramType")), resultvar.name),
                initargs=[ExprVar("std::in_place")] + ctorargs,
            )
        )

        # After non-pod data, bulk read/write pod data in member order. This has
        # to be done after the result has been constructed, so that we have
        # somewhere to read into.
        # groupby only merges adjacent items; member order keeps equal sizes
        # adjacent, so each group is one contiguous run.
        for size, fields in itertools.groupby(
            sd.fields_member_order(), lambda f: pod_size(f.ipdltype)
        ):
            if size != pod_size_sentinel:
                fields = list(fields)
                write.append(cls.checkedBulkWrite(cls.var, size, fields))
                read.append(cls.checkedBulkRead(resultvar, size, fields))

        read.append(StmtReturn(resultvar))

        return cls.generateDecl(
            cxxtype, write, read, needsmove=_cxxTypeNeedsMoveForSend(structtype)
        )

    @classmethod
    def unionPickling(cls, uniontype):
        # ParamTraits for an IPDL union: write/read the discriminant `type`
        # first, then switch on it to write/read the active variant. Unknown
        # discriminants raise FatalError.
        # NOTE: Not using _cxxBareType here as we don't have a side
        cxxtype = Type(uniontype.fullname())
        ud = uniontype._ast

        # Use typedef to set up an alias so it's easier to reference the struct type.
        alias = "union__"
        typevar = ExprVar("type")

        prelude = [
            Typedef(cxxtype, alias),
        ]

        writeswitch = StmtSwitch(typevar)
        write = prelude + [
            StmtDecl(Decl(Type.INT, typevar.name), init=ud.callType(cls.var)),
            cls.checkedWrite(
                None, typevar, cls.writervar, sentinelKey=uniontype.name()
            ),
            Whitespace.NL,
            writeswitch,
        ]

        readswitch = StmtSwitch(typevar)
        read = prelude + [
            cls._checkedRead(
                None,
                Type.INT,
                typevar,
                uniontype.name(),
                "type of union " + uniontype.name(),
            ),
            Whitespace.NL,
            readswitch,
        ]

        for c in ud.components:
            caselabel = CaseLabel(alias + "::" + c.enum())
            origenum = c.enum()

            writecase = StmtBlock()
            wstmt = cls.checkedWrite(
                c.ipdltype,
                ExprCall(ExprSelect(cls.var, ".", c.getTypeName())),
                cls.writervar,
                sentinelKey=c.enum(),
            )
            writecase.addstmts([wstmt, StmtReturn()])
            writeswitch.addcase(caselabel, writecase)

            readcase = StmtBlock()
            tmpvar = ExprVar("tmp")
            readcase.addstmts(
                [
                    cls._checkedRead(
                        c.ipdltype,
                        c.bareType(fq=True),
                        tmpvar,
                        origenum,
                        "variant " + origenum + " of union " + uniontype.name(),
                    ),
                    StmtReturn(ExprMove(tmpvar)),
                ]
            )
            readswitch.addcase(caselabel, readcase)

        # Add the error default case
        writeswitch.addcase(
            DefaultLabel(),
            StmtBlock(
                [
                    cls.fatalError(
                        cls.writervar, "unknown variant of union " + uniontype.name()
                    ),
                    StmtReturn(),
                ]
            ),
        )
        readswitch.addcase(
            DefaultLabel(),
            StmtBlock(
                [
                    cls.fatalError(
                        cls.readervar, "unknown variant of union " + uniontype.name()
                    ),
                    StmtReturn(readResultError()),
                ]
            ),
        )

        return cls.generateDecl(
            cxxtype, write, read, needsmove=_cxxTypeNeedsMoveForSend(uniontype)
        )
# --------------------------------------------------
class _ComputeTypeDeps(TypeVisitor):
    """Pass that gathers the C++ types that a particular IPDL type
    (recursively) depends on. There are three kinds of dependencies: (i)
    types that need forward declaration; (ii) types that need a |using|
    stmt; (iii) IPDL structs or unions which must be fully declared
    before this struct. Some types generate multiple kinds.

    Results accumulate in |usingTypedefs|, |forwardDeclStmts|,
    |fullDeclTypes| and, when |typesToIncludes| (fq type name -> header) is
    given, |includeHeaders|.
    """

    def __init__(self, fortype, typesToIncludes=None):
        ipdl.type.TypeVisitor.__init__(self)
        self.usingTypedefs = []
        self.forwardDeclStmts = []
        self.fullDeclTypes = []
        self.includeHeaders = set()
        # The type whose dependencies are being computed; it is never
        # recorded as a dependency of itself.
        self.fortype = fortype
        self.typesToIncludes = typesToIncludes

    def maybeTypedef(self, fqname, name, templateargs=None):
        """Record `using name = fqname` when the names differ, and the header
        that provides |fqname| if one is known.

        |templateargs| defaults to a fresh empty list per call. A shared
        mutable default would be aliased by every Typedef created without
        template args.
        """
        assert fqname.startswith("::")
        if templateargs is None:
            templateargs = []
        if fqname != name:
            self.usingTypedefs.append(Typedef(Type(fqname), name, templateargs))
        if self.typesToIncludes is not None and fqname in self.typesToIncludes:
            self.includeHeaders.add(self.typesToIncludes[fqname])

    def visitImportedCxxType(self, t):
        if t in self.visited:
            return
        self.visited.add(t)
        self.maybeTypedef(t.fullname(), t.name())

    def visitActorType(self, t):
        if t in self.visited:
            return
        self.visited.add(t)

        fqname, name = t.fullname(), t.name()

        self.includeHeaders.add("mozilla/ipc/SideVariant.h")
        self.maybeTypedef(_actorName(fqname, "Parent"), _actorName(name, "Parent"))
        self.maybeTypedef(_actorName(fqname, "Child"), _actorName(name, "Child"))

        self.forwardDeclStmts.extend(
            [
                _makeForwardDeclForActor(t.protocol, "parent"),
                Whitespace.NL,
                _makeForwardDeclForActor(t.protocol, "child"),
                Whitespace.NL,
            ]
        )

    def visitStructOrUnionType(self, su, defaultVisit):
        if su in self.visited or su == self.fortype:
            return
        self.visited.add(su)
        self.maybeTypedef(su.fullname(), su.name())

        # Mutually recursive fields in unions are behind indirection, so we only
        # need a forward decl, and don't need a full type declaration.
        if isinstance(self.fortype, UnionType) and self.fortype.mutuallyRecursiveWith(
            su
        ):
            self.forwardDeclStmts.append(_makeForwardDecl(su))
        else:
            self.fullDeclTypes.append(su)

        return defaultVisit(self, su)

    def visitStructType(self, t):
        return self.visitStructOrUnionType(t, TypeVisitor.visitStructType)

    def visitUnionType(self, t):
        return self.visitStructOrUnionType(t, TypeVisitor.visitUnionType)

    def visitArrayType(self, t):
        return TypeVisitor.visitArrayType(self, t)

    def visitMaybeType(self, m):
        return TypeVisitor.visitMaybeType(self, m)

    def visitShmemType(self, s):
        if s in self.visited:
            return
        self.visited.add(s)
        self.maybeTypedef("::mozilla::ipc::Shmem", "Shmem")

    def visitByteBufType(self, s):
        if s in self.visited:
            return
        self.visited.add(s)
        self.maybeTypedef("::mozilla::ipc::ByteBuf", "ByteBuf")

    def visitFDType(self, s):
        if s in self.visited:
            return
        self.visited.add(s)
        self.maybeTypedef("::mozilla::ipc::FileDescriptor", "FileDescriptor")

    def visitEndpointType(self, s):
        if s in self.visited:
            return
        self.visited.add(s)
        self.includeHeaders.add("mozilla/ipc/Endpoint.h")
        self.maybeTypedef("::mozilla::ipc::Endpoint", "Endpoint", ["FooSide"])
        self.visitActorType(s.actor)

    def visitManagedEndpointType(self, s):
        if s in self.visited:
            return
        self.visited.add(s)
        self.includeHeaders.add("mozilla/ipc/Endpoint.h")
        self.maybeTypedef(
            "::mozilla::ipc::ManagedEndpoint", "ManagedEndpoint", ["FooSide"]
        )
        self.visitActorType(s.actor)

    def visitUniquePtrType(self, s):
        return TypeVisitor.visitUniquePtrType(self, s)

    # These types can never appear inside a struct/union field type.
    def visitVoidType(self, v):
        assert 0

    def visitMessageType(self, v):
        assert 0

    def visitProtocolType(self, v):
        assert 0
def _fieldStaticAssertions(sd):
    """Return C++ static_assert statements for struct |sd|.

    Each run of two or more same-sized POD members (in member order) gets one
    check: the distance between the offsets of the first and last member must
    equal size * (count - 1), i.e. no padding between them. The bulk
    WriteBytes/ReadBytesInto serialization of those runs relies on this.
    """

    def podSizeOf(field):
        return pod_size(field.ipdltype)

    checks = []
    for size, run in itertools.groupby(sd.fields_member_order(), podSizeOf):
        if size == pod_size_sentinel:
            # Non-POD members are serialized one by one; nothing to check.
            continue
        members = list(run)
        if len(members) < 2:
            continue
        checks.append(
            StmtCode(
                """
static_assert(
(offsetof(${struct}, ${last}) - offsetof(${struct}, ${first})) == ${expected},
"Bad assumptions about field layout!");
""",
                struct=sd.name,
                first=members[0].memberVar(),
                last=members[-1].memberVar(),
                expected=ExprLiteral.Int(size * (len(members) - 1)),
            )
        )
    return checks
def _generateCxxStruct(sd):
    """Build the C++ class for IPDL struct |sd|.

    Returns (forwarddeclstmts, fulldecltypes, struct): the forward
    declarations the class needs, the IPDL struct/union types that must be
    fully declared before it, and the generated Class node. The class gets
    private using-typedefs, a defaulted default constructor, value
    constructors (const& and, when it differs, move), operator==/!= when the
    struct has the "Comparable" attribute, const and non-const accessors,
    optional layout static_asserts, and the data members in packed member
    order.
    """
    # compute all the typedefs and forward decls we need to make
    gettypedeps = _ComputeTypeDeps(sd.decl.type)
    for f in sd.fields:
        f.ipdltype.accept(gettypedeps)

    usingTypedefs = gettypedeps.usingTypedefs
    forwarddeclstmts = gettypedeps.forwardDeclStmts
    fulldecltypes = gettypedeps.fullDeclTypes

    struct = Class(sd.name, final=True)
    struct.addstmts([Label.PRIVATE] + usingTypedefs + [Whitespace.NL, Label.PUBLIC])

    constreftype = Type(sd.name, const=True, ref=True)

    # Struct()
    # We want the default constructor to be declared if it is available, but
    # some of our members may not be default-constructible. Silence the
    # warning which clang generates in that case.
    #
    # Members which need value initialization will be handled by wrapping
    # the member in a template type when declaring them.
    struct.addcode(
        """
#ifdef __clang__
# pragma clang diagnostic push
# if __has_warning("-Wdefaulted-function-deleted")
# pragma clang diagnostic ignored "-Wdefaulted-function-deleted"
# endif
#endif
${name}() = default;
#ifdef __clang__
# pragma clang diagnostic pop
#endif
""",
        name=sd.name,
    )

    # If this is an empty struct (no fields), then the default ctor
    # and "create-with-fields" ctors are equivalent.
    if len(sd.fields):
        assert len(sd.fields) == len(sd.packed_field_order)

        # Struct(const field1& _f1, ...)
        # Parameters follow IPDL order; member initializers follow member
        # (packed) order.
        valctor = ConstructorDefn(
            ConstructorDecl(
                sd.name,
                params=[
                    Decl(
                        (
                            f.forceMoveType()
                            if _cxxTypeNeedsMoveForData(f.ipdltype)
                            else f.constRefType()
                        ),
                        f.argVar().name,
                    )
                    for f in sd.fields_ipdl_order()
                ],
                force_inline=True,
            )
        )
        valctor.memberinits = []
        for f in sd.fields_member_order():
            arg = f.argVar()
            if _cxxTypeNeedsMoveForData(f.ipdltype):
                arg = ExprMove(arg)
            valctor.memberinits.append(ExprMemberInit(f.memberVar(), args=[arg]))

        struct.addstmts([valctor, Whitespace.NL])

        # If a constructor which moves each argument would be different from the
        # `const T&` version, also generate that constructor.
        if not all(
            _cxxTypeNeedsMoveForData(f.ipdltype) or not _cxxTypeCanMove(f.ipdltype)
            for f in sd.fields_ipdl_order()
        ):
            # Struct(field1&& _f1, ...)
            valmovector = ConstructorDefn(
                ConstructorDecl(
                    sd.name,
                    params=[
                        Decl(
                            (
                                f.forceMoveType()
                                if _cxxTypeCanMove(f.ipdltype)
                                else f.constRefType()
                            ),
                            f.argVar().name,
                        )
                        for f in sd.fields_ipdl_order()
                    ],
                    force_inline=True,
                )
            )

            valmovector.memberinits = []
            for f in sd.fields_member_order():
                arg = f.argVar()
                if _cxxTypeCanMove(f.ipdltype):
                    arg = ExprMove(arg)
                valmovector.memberinits.append(
                    ExprMemberInit(f.memberVar(), args=[arg])
                )

            struct.addstmts([valmovector, Whitespace.NL])

    # The default copy, move, and assignment constructors, and the default
    # destructor, will do the right thing.

    if "Comparable" in sd.attributes:
        # bool operator==(const Struct& _o)
        # Compares fields in IPDL order and returns false at the first
        # mismatch.
        ovar = ExprVar("_o")
        opeqeq = MethodDefn(
            MethodDecl(
                "operator==",
                params=[Decl(constreftype, ovar.name)],
                ret=Type.BOOL,
                const=True,
            )
        )
        for f in sd.fields_ipdl_order():
            ifneq = StmtIf(
                ExprNot(
                    ExprBinary(
                        ExprCall(f.getMethod()), "==", ExprCall(f.getMethod(ovar))
                    )
                )
            )
            ifneq.addifstmt(StmtReturn.FALSE)
            opeqeq.addstmt(ifneq)
        opeqeq.addstmt(StmtReturn.TRUE)
        struct.addstmts([opeqeq, Whitespace.NL])

        # bool operator!=(const Struct& _o)
        opneq = MethodDefn(
            MethodDecl(
                "operator!=",
                params=[Decl(constreftype, ovar.name)],
                ret=Type.BOOL,
                const=True,
            )
        )
        opneq.addstmt(StmtReturn(ExprNot(ExprCall(ExprVar("operator=="), args=[ovar]))))
        struct.addstmts([opneq, Whitespace.NL])

    # field1& f1()
    # const field1& f1() const
    for f in sd.fields_ipdl_order():
        get = MethodDefn(
            MethodDecl(
                f.getMethod().name, params=[], ret=f.refType(), force_inline=True
            )
        )
        get.addstmt(StmtReturn(f.refExpr()))

        # The const overload reuses a copy of the non-const declaration.
        getconstdecl = deepcopy(get.decl)
        getconstdecl.ret = f.constRefType()
        getconstdecl.const = True

        getconst = MethodDefn(getconstdecl)
        getconst.addstmt(StmtReturn(f.constRefExpr()))

        struct.addstmts([get, getconst, Whitespace.NL])

    # private:
    struct.addstmt(Label.PRIVATE)

    # Static assertions to ensure our assumptions about field layout match
    # what the compiler is actually producing. We define this as a member
    # function, rather than throwing the assertions in the constructor or
    # similar, because we don't want to evaluate the static assertions every
    # time the header file containing the structure is included.
    staticasserts = _fieldStaticAssertions(sd)
    if staticasserts:
        method = MethodDefn(
            MethodDecl("StaticAssertions", params=[], ret=Type.VOID, const=True)
        )
        method.addstmts(staticasserts)
        struct.addstmts([method])

    # members
    struct.addstmts(
        [
            StmtDecl(Decl(_effectiveMemberType(f), f.memberVar().name))
            for f in sd.fields_member_order()
        ]
    )

    return forwarddeclstmts, fulldecltypes, struct
def _effectiveMemberType(f):
    """Return the C++ type of the data member that backs struct field |f|.

    The field's bare type is wrapped in ::mozilla::ipc::IPDLStructMember<T>.
    Structs must be copyable for backwards compatibility reasons, so an
    nsTArray<T> field is stored as CopyableTArray<T>. This is not exposed in
    the method signatures: those keep using nsTArray<T>, which is a base
    class of CopyableTArray<T>.
    """
    memberType = f.bareType()
    isPlainArray = memberType.name == "nsTArray"
    if isPlainArray:
        # NOTE(review): this renames the Type object returned by bareType()
        # in place; assumes bareType() hands back a fresh Type -- confirm.
        memberType.name = "CopyableTArray"
    templateArgs = [memberType]
    return Type("::mozilla::ipc::IPDLStructMember", T=templateArgs)
# --------------------------------------------------
def _generateCxxUnion(ud):
    """Lower the IPDL union declaration |ud| to a C++ discriminated-union class.

    Returns a tuple |(forwarddeclstmts, fulldecltypes, cls)|:
      - forwarddeclstmts / fulldecltypes: the forward-declaration statements
        and full-declaration types collected by _ComputeTypeDeps over every
        component type;
      - cls: the generated (final) Class node.

    NOTE: the C++ emitted depends on the exact order of the addstmt()/addcase()
    calls below; keep that order when editing.
    """
    # This Union class basically consists of a type (enum) and a
    # union for storage. The union can contain POD and non-POD
    # types. Each type needs a copy/move ctor, assignment operators,
    # and dtor.
    #
    # Rather than templating this class and only providing
    # specializations for the types we support, which is slightly
    # "unsafe" in that C++ code can add additional specializations
    # without the IPDL compiler's knowledge, we instead explicitly
    # implement non-templated methods for each supported type.
    #
    # The one complication that arises is that C++, for arcane
    # reasons, does not allow the placement destructor of a
    # builtin type, like int, to be directly invoked. So we need
    # to hack around this by internally typedef'ing all
    # constituent types. Sigh.
    #
    # So, for each type, this "Union" class needs:
    # (private)
    # - entry in the type enum
    # - entry in the storage union
    # - [type]ptr() method to get a type* from the underlying union
    # - same as above to get a const type*
    # - typedef to hack around placement delete limitations
    # (public)
    # - placement delete case for dtor
    # - copy ctor
    # - move ctor
    # - case in generic copy ctor
    # - copy operator= impl
    # - move operator= impl
    # - case in generic operator=
    # - operator [type&]
    # - operator [const type&] const
    # - [type&] get_[type]()
    # - [const type&] get_[type]() const
    #
    cls = Class(ud.name, final=True)
    # const Union&, i.e., Union type with inparam semantics
    inClsType = Type(ud.name, const=True, ref=True)
    # Union& (return type of the operator= overloads)
    refClsType = Type(ud.name, ref=True)
    # Union&& (parameter of the move ctor / move operator=)
    rvalueRefClsType = Type(ud.name, rvalref=True)
    typetype = Type("Type")
    mtypevar = ExprVar("mType")
    maybedtorvar = ExprVar("MaybeDestroy")
    assertsanityvar = ExprVar("AssertSanity")
    tnonevar = ExprVar("T__None")
    tlastvar = ExprVar("T__Last")

    def callAssertSanity(uvar=None, expectTypeVar=None):
        # Build an AssertSanity() call expression: on |uvar| (as
        # uvar.AssertSanity(...)) when given, otherwise on |this|; passing
        # |expectTypeVar| selects the overload that also checks the tag.
        func = assertsanityvar
        args = []
        if uvar is not None:
            func = ExprSelect(uvar, ".", assertsanityvar.name)
        if expectTypeVar is not None:
            args.append(expectTypeVar)
        return ExprCall(func, args=args)

    def maybeDestroy():
        # Statement form of |MaybeDestroy();| on |this|.
        return StmtExpr(ExprCall(maybedtorvar))

    # compute all the typedefs and forward decls we need to make
    gettypedeps = _ComputeTypeDeps(ud.decl.type)
    for c in ud.components:
        c.ipdltype.accept(gettypedeps)
    usingTypedefs = gettypedeps.usingTypedefs
    forwarddeclstmts = gettypedeps.forwardDeclStmts
    fulldecltypes = gettypedeps.fullDeclTypes

    # the |Type| enum, used to switch on the discunion's real type.
    # T__None is 0 and the first component is 1; the remaining components
    # are added without an explicit value (presumably auto-incrementing), and
    # T__Last aliases the last component so AssertSanity() can range-check.
    cls.addstmt(Label.PUBLIC)
    typeenum = TypeEnum(typetype.name)
    typeenum.addId(tnonevar.name, 0)
    firstid = ud.components[0].enum()
    typeenum.addId(firstid, 1)
    for c in ud.components[1:]:
        typeenum.addId(c.enum())
    typeenum.addId(tlastvar.name, ud.components[-1].enum())
    cls.addstmts([StmtDecl(Decl(typeenum, "")), Whitespace.NL])

    cls.addstmt(Label.PRIVATE)
    cls.addstmts(
        usingTypedefs
        # hacky typedef's that allow placement dtors of builtins
        + [Typedef(c.internalType(), c.typedef()) for c in ud.components]
    )
    cls.addstmt(Whitespace.NL)

    # the C++ union the discunion use for storage
    valuetype = TypeUnion()
    for c in ud.components:
        valuetype.addComponent(c.unionType(), c.name)

    # for each constituent type T, add private accessors that
    # return a pointer to the Value union storage casted to |T*|
    # and |const T*|
    for c in ud.components:
        getptr = MethodDefn(
            MethodDecl(
                c.getPtrName(), params=[], ret=c.ptrToInternalType(), force_inline=True
            )
        )
        getptr.addstmt(StmtReturn(c.ptrToSelfExpr()))

        getptrconst = MethodDefn(
            MethodDecl(
                c.getConstPtrName(),
                params=[],
                ret=c.constPtrToType(),
                const=True,
                force_inline=True,
            )
        )
        getptrconst.addstmt(StmtReturn(c.constptrToSelfExpr()))

        cls.addstmts([getptr, getptrconst])
    cls.addstmt(Whitespace.NL)

    # add a helper method, MaybeDestroy(), that invokes the placement dtor
    # on the current underlying value, if there is one. It takes no
    # arguments, returns void, and does NOT reset |mType| itself: the
    # generated callers below either assign |mType| afterwards or are the
    # destructor.
    maybedtor = MethodDefn(MethodDecl(maybedtorvar.name, ret=Type.VOID))
    # if (mType == T__None) return;  -- nothing is constructed, so there is
    # nothing to destroy
    ifnone = StmtIf(ExprBinary(mtypevar, "==", tnonevar))
    ifnone.addifstmt(StmtReturn())
    # need to destroy. switch on underlying type
    dtorswitch = StmtSwitch(mtypevar)
    for c in ud.components:
        dtorswitch.addcase(
            CaseLabel(c.enum()), StmtBlock([StmtExpr(c.callDtor()), StmtBreak()])
        )
    dtorswitch.addcase(
        DefaultLabel(), StmtBlock([_logicError("not reached"), StmtBreak()])
    )
    maybedtor.addstmts([ifnone, dtorswitch])
    cls.addstmts([maybedtor, Whitespace.NL])

    # add helper methods that ensure the discunion has a
    # valid type: AssertSanity() checks T__None <= mType <= T__Last ...
    sanity = MethodDefn(
        MethodDecl(assertsanityvar.name, ret=Type.VOID, const=True, force_inline=True)
    )
    sanity.addstmts(
        [
            _abortIfFalse(ExprBinary(tnonevar, "<=", mtypevar), "invalid type tag"),
            _abortIfFalse(ExprBinary(mtypevar, "<=", tlastvar), "invalid type tag"),
        ]
    )
    cls.addstmt(sanity)

    # ... and AssertSanity(aType) additionally checks mType == aType.
    atypevar = ExprVar("aType")
    sanity2 = MethodDefn(
        MethodDecl(
            assertsanityvar.name,
            params=[Decl(typetype, atypevar.name)],
            ret=Type.VOID,
            const=True,
            force_inline=True,
        )
    )
    sanity2.addstmts(
        [
            StmtExpr(ExprCall(assertsanityvar)),
            _abortIfFalse(ExprBinary(mtypevar, "==", atypevar), "unexpected type tag"),
        ]
    )
    cls.addstmts([sanity2, Whitespace.NL])

    # ---- begin public methods -----

    # Union() default ctor: starts out empty (mType = T__None)
    cls.addstmts(
        [
            Label.PUBLIC,
            ConstructorDefn(
                ConstructorDecl(ud.name, force_inline=True),
                memberinits=[ExprMemberInit(mtypevar, [tnonevar])],
            ),
            Whitespace.NL,
        ]
    )

    # Union(const T&) copy & Union(T&&) move ctors. The copy ctor is
    # omitted for component types that need move semantics for their data,
    # and the move ctor for types that cannot be moved.
    othervar = ExprVar("aOther")
    for c in ud.components:
        if not _cxxTypeNeedsMoveForData(c.ipdltype):
            copyctor = ConstructorDefn(
                ConstructorDecl(ud.name, params=[Decl(c.constRefType(), othervar.name)])
            )
            copyctor.addstmts(
                [
                    StmtExpr(c.callCtor(othervar)),
                    StmtExpr(ExprAssn(mtypevar, c.enumvar())),
                ]
            )
            cls.addstmts([copyctor, Whitespace.NL])

        if not _cxxTypeCanMove(c.ipdltype):
            continue
        movector = ConstructorDefn(
            ConstructorDecl(ud.name, params=[Decl(c.forceMoveType(), othervar.name)])
        )
        movector.addstmts(
            [
                StmtExpr(c.callCtor(ExprMove(othervar))),
                StmtExpr(ExprAssn(mtypevar, c.enumvar())),
            ]
        )
        cls.addstmts([movector, Whitespace.NL])

    # If any component is move-only, the union as a whole gets no copy ctor
    # and no copy operator=.
    unionNeedsMove = any(_cxxTypeNeedsMoveForData(c.ipdltype) for c in ud.components)

    # Union(const Union&) copy ctor
    if not unionNeedsMove:
        copyctor = ConstructorDefn(
            ConstructorDecl(ud.name, params=[Decl(inClsType, othervar.name)])
        )
        othertype = ud.callType(othervar)
        copyswitch = StmtSwitch(othertype)
        for c in ud.components:
            copyswitch.addcase(
                CaseLabel(c.enum()),
                StmtBlock(
                    [
                        StmtExpr(
                            c.callCtor(
                                ExprCall(
                                    ExprSelect(othervar, ".", c.getConstTypeName())
                                )
                            )
                        ),
                        StmtBreak(),
                    ]
                ),
            )
        copyswitch.addcase(CaseLabel(tnonevar.name), StmtBlock([StmtBreak()]))
        copyswitch.addcase(
            DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn()])
        )
        copyctor.addstmts(
            [
                StmtExpr(callAssertSanity(uvar=othervar)),
                copyswitch,
                StmtExpr(ExprAssn(mtypevar, othertype)),
            ]
        )
        cls.addstmts([copyctor, Whitespace.NL])

    # Union(Union&&) move ctor. The moved-from union is left empty
    # (aOther.mType = T__None).
    movector = ConstructorDefn(
        ConstructorDecl(ud.name, params=[Decl(rvalueRefClsType, othervar.name)])
    )
    othertypevar = ExprVar("t")
    moveswitch = StmtSwitch(othertypevar)
    for c in ud.components:
        case = StmtBlock()
        if c.recursive:
            # Recursive components are handled through their storage
            # pointer: the pointer is stolen instead of move-constructing.
            # This is sound as we set othervar.mTypeVar to T__None after the
            # switch. The pointer in the union will be left dangling.
            case.addstmts(
                [
                    # ptr_C() = other.ptr_C()
                    StmtExpr(
                        ExprAssn(
                            c.callGetPtr(),
                            ExprCall(
                                ExprSelect(othervar, ".", ExprVar(c.getPtrName()))
                            ),
                        )
                    )
                ]
            )
        else:
            case.addstmts(
                [
                    # new ... (Move(other.get_C()))
                    StmtExpr(
                        c.callCtor(
                            ExprMove(
                                ExprCall(ExprSelect(othervar, ".", c.getTypeName()))
                            )
                        )
                    ),
                    # other.MaybeDestroy()
                    StmtExpr(ExprCall(ExprSelect(othervar, ".", maybedtorvar))),
                ]
            )
        case.addstmts([StmtBreak()])
        moveswitch.addcase(CaseLabel(c.enum()), case)
    moveswitch.addcase(CaseLabel(tnonevar.name), StmtBlock([StmtBreak()]))
    moveswitch.addcase(
        DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn()])
    )
    movector.addstmts(
        [
            StmtExpr(callAssertSanity(uvar=othervar)),
            StmtDecl(Decl(typetype, othertypevar.name), init=ud.callType(othervar)),
            moveswitch,
            StmtExpr(ExprAssn(ExprSelect(othervar, ".", mtypevar), tnonevar)),
            StmtExpr(ExprAssn(mtypevar, othertypevar)),
        ]
    )
    cls.addstmts([movector, Whitespace.NL])

    # ~Union()
    dtor = DestructorDefn(DestructorDecl(ud.name))
    dtor.addstmt(maybeDestroy())
    cls.addstmts([dtor, Whitespace.NL])

    # type()
    typemeth = MethodDefn(
        MethodDecl("type", ret=typetype, const=True, force_inline=True)
    )
    typemeth.addstmt(StmtReturn(mtypevar))
    cls.addstmts([typemeth, Whitespace.NL])

    # Union& operator= methods
    rhsvar = ExprVar("aRhs")
    for c in ud.components:

        # Shared body: destroy the current value, construct the new one from
        # |rhs|, update the tag, return *this. Defined per iteration and used
        # immediately, so capturing the loop variable |c| is safe.
        # NOTE(review): there is no aliasing guard -- for |u = u.get_C()| the
        # generated MaybeDestroy() appears to run before copying from the
        # (same) argument. Confirm whether that case must be supported.
        def opeqBody(rhs):
            return [
                # might need to placement-delete old value first
                maybeDestroy(),
                StmtExpr(c.callCtor(rhs)),
                StmtExpr(ExprAssn(mtypevar, c.enumvar())),
                StmtReturn(ExprDeref(ExprVar.THIS)),
            ]

        if not _cxxTypeNeedsMoveForData(c.ipdltype):
            # Union& operator=(const T&)
            opeq = MethodDefn(
                MethodDecl(
                    "operator=",
                    params=[Decl(c.constRefType(), rhsvar.name)],
                    ret=refClsType,
                )
            )
            opeq.addstmts(opeqBody(rhsvar))
            cls.addstmts([opeq, Whitespace.NL])

        # Union& operator=(T&&)
        if not _cxxTypeCanMove(c.ipdltype):
            continue

        opeq = MethodDefn(
            MethodDecl(
                "operator=",
                params=[Decl(c.forceMoveType(), rhsvar.name)],
                ret=refClsType,
            )
        )
        opeq.addstmts(opeqBody(ExprMove(rhsvar)))
        cls.addstmts([opeq, Whitespace.NL])

    # Union& operator=(const Union&)
    # NOTE(review): no |this == &aRhs| guard is generated. On
    # self-assignment, each case calls MaybeDestroy() before copy-constructing
    # from aRhs.get_C(), which at that point looks like the value that was
    # just destroyed. Confirm whether self-assignment must be supported.
    if not unionNeedsMove:
        opeq = MethodDefn(
            MethodDecl(
                "operator=", params=[Decl(inClsType, rhsvar.name)], ret=refClsType
            )
        )
        rhstypevar = ExprVar("t")
        opeqswitch = StmtSwitch(rhstypevar)
        for c in ud.components:
            case = StmtBlock()
            case.addstmts(
                [
                    maybeDestroy(),
                    StmtExpr(
                        c.callCtor(
                            ExprCall(ExprSelect(rhsvar, ".", c.getConstTypeName()))
                        )
                    ),
                    StmtBreak(),
                ]
            )
            opeqswitch.addcase(CaseLabel(c.enum()), case)
        opeqswitch.addcase(
            CaseLabel(tnonevar.name),
            StmtBlock([maybeDestroy(), StmtBreak()]),
        )
        opeqswitch.addcase(
            DefaultLabel(), StmtBlock([_logicError("unreached"), StmtBreak()])
        )
        opeq.addstmts(
            [
                StmtExpr(callAssertSanity(uvar=rhsvar)),
                StmtDecl(Decl(typetype, rhstypevar.name), init=ud.callType(rhsvar)),
                opeqswitch,
                StmtExpr(ExprAssn(mtypevar, rhstypevar)),
                StmtReturn(ExprDeref(ExprVar.THIS)),
            ]
        )
        cls.addstmts([opeq, Whitespace.NL])

    # Union& operator=(Union&&)
    # Mirrors the move ctor: recursive components steal the storage pointer,
    # others are move-constructed and then destroyed in aRhs. Afterwards
    # aRhs.mType = T__None.
    opeq = MethodDefn(
        MethodDecl(
            "operator=", params=[Decl(rvalueRefClsType, rhsvar.name)], ret=refClsType
        )
    )
    rhstypevar = ExprVar("t")
    opeqswitch = StmtSwitch(rhstypevar)
    for c in ud.components:
        case = StmtBlock()
        if c.recursive:
            case.addstmts(
                [
                    maybeDestroy(),
                    # ptr_C() = aRhs.ptr_C()
                    StmtExpr(
                        ExprAssn(
                            c.callGetPtr(),
                            ExprCall(ExprSelect(rhsvar, ".", ExprVar(c.getPtrName()))),
                        )
                    ),
                ]
            )
        else:
            case.addstmts(
                [
                    maybeDestroy(),
                    StmtExpr(
                        c.callCtor(
                            ExprMove(ExprCall(ExprSelect(rhsvar, ".", c.getTypeName())))
                        )
                    ),
                    # other.MaybeDestroy()
                    StmtExpr(ExprCall(ExprSelect(rhsvar, ".", maybedtorvar))),
                ]
            )
        case.addstmts([StmtBreak()])
        opeqswitch.addcase(CaseLabel(c.enum()), case)
    opeqswitch.addcase(
        CaseLabel(tnonevar.name),
        StmtBlock([maybeDestroy(), StmtBreak()]),
    )
    opeqswitch.addcase(
        DefaultLabel(), StmtBlock([_logicError("unreached"), StmtBreak()])
    )
    opeq.addstmts(
        [
            StmtExpr(callAssertSanity(uvar=rhsvar)),
            StmtDecl(Decl(typetype, rhstypevar.name), init=ud.callType(rhsvar)),
            opeqswitch,
            StmtExpr(ExprAssn(ExprSelect(rhsvar, ".", mtypevar), tnonevar)),
            StmtExpr(ExprAssn(mtypevar, rhstypevar)),
            StmtReturn(ExprDeref(ExprVar.THIS)),
        ]
    )
    cls.addstmts([opeq, Whitespace.NL])

    if "Comparable" in ud.attributes:
        # bool operator==(const T&)
        # NOTE(review): this compares get_C() against aRhs, and the get_C()
        # accessors (generated below) call AssertSanity(TC). Comparing a union
        # holding a different member type therefore fails that check rather
        # than returning false -- confirm this is intended.
        for c in ud.components:
            opeqeq = MethodDefn(
                MethodDecl(
                    "operator==",
                    params=[Decl(c.constRefType(), rhsvar.name)],
                    ret=Type.BOOL,
                    const=True,
                )
            )
            opeqeq.addstmt(
                StmtReturn(ExprBinary(ExprCall(ExprVar(c.getTypeName())), "==", rhsvar))
            )
            cls.addstmts([opeqeq, Whitespace.NL])

        # bool operator==(const Union&): differing tags compare unequal;
        # otherwise compare the active members.
        opeqeq = MethodDefn(
            MethodDecl(
                "operator==",
                params=[Decl(inClsType, rhsvar.name)],
                ret=Type.BOOL,
                const=True,
            )
        )
        iftypesmismatch = StmtIf(ExprBinary(ud.callType(), "!=", ud.callType(rhsvar)))
        iftypesmismatch.addifstmt(StmtReturn.FALSE)
        opeqeq.addstmts([iftypesmismatch, Whitespace.NL])

        opeqeqswitch = StmtSwitch(ud.callType())
        for c in ud.components:
            case = StmtBlock()
            case.addstmt(
                StmtReturn(
                    ExprBinary(
                        ExprCall(ExprVar(c.getTypeName())),
                        "==",
                        ExprCall(ExprSelect(rhsvar, ".", c.getTypeName())),
                    )
                )
            )
            opeqeqswitch.addcase(CaseLabel(c.enum()), case)
        opeqeqswitch.addcase(
            DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn.FALSE])
        )
        opeqeq.addstmt(opeqeqswitch)

        cls.addstmts([opeqeq, Whitespace.NL])

    # accessors for each type: operator T&, operator const T&,
    # T& get(), const T& get()
    # The get_* accessors assert that C is the active member; the conversion
    # operators simply forward to them.
    for c in ud.components:
        getValueVar = ExprVar(c.getTypeName())
        getConstValueVar = ExprVar(c.getConstTypeName())

        getvalue = MethodDefn(
            MethodDecl(getValueVar.name, ret=c.refType(), force_inline=True)
        )
        getvalue.addstmts(
            [
                StmtExpr(callAssertSanity(expectTypeVar=c.enumvar())),
                StmtReturn(ExprDeref(c.callGetPtr())),
            ]
        )

        getconstvalue = MethodDefn(
            MethodDecl(
                getConstValueVar.name,
                ret=c.constRefType(),
                const=True,
                force_inline=True,
            )
        )
        getconstvalue.addstmts(
            [
                StmtExpr(callAssertSanity(expectTypeVar=c.enumvar())),
                StmtReturn(c.getConstValue()),
            ]
        )

        cls.addstmts([getvalue, getconstvalue])

        optype = MethodDefn(MethodDecl("", typeop=c.refType(), force_inline=True))
        optype.addstmt(StmtReturn(ExprCall(getValueVar)))
        opconsttype = MethodDefn(
            MethodDecl("", const=True, typeop=c.constRefType(), force_inline=True)
        )
        opconsttype.addstmt(StmtReturn(ExprCall(getConstValueVar)))

        cls.addstmts([optype, opconsttype, Whitespace.NL])

    # private vars: the anonymous storage union and the type tag
    cls.addstmts(
        [
            Label.PRIVATE,
            StmtDecl(Decl(valuetype, "")),
            StmtDecl(Decl(typetype, mtypevar.name)),
        ]
    )

    return forwarddeclstmts, fulldecltypes, cls
# -----------------------------------------------------------------------------
class _FindFriends(ipdl.ast.Visitor):
    """Compute which protocols must be |friend|s of a protocol's actor class.

    Starting from each toplevel protocol of |mytype|, the |manages| tree is
    walked. Every visited protocol other than |mytype| itself whose messages
    carry an actor of |mytype| among their in- or out-parameters is recorded
    as a friend.
    """

    def __init__(self):
        self.mytype = None  # ProtocolType whose friends are being computed
        self.vtype = None  # ProtocolType whose AST is currently being visited
        self.friends = set()  # set<ProtocolType>

    def findFriends(self, ptype):
        """Return the set of protocols that must befriend |ptype|'s actor."""
        self.mytype = ptype
        for root in ptype.toplevels():
            self.walkDownTheProtocolTree(root)
        return self.friends

    # TODO could make this into a _iterProtocolTreeHelper ...
    def walkDownTheProtocolTree(self, ptype):
        """Visit |ptype| (unless it is |mytype|), then recurse into its managees."""
        # don't want to |friend| ourself!
        if ptype != self.mytype:
            self.visit(ptype)
        for managee in ptype.manages:
            # Skip a protocol that lists itself in |manages| so the recursion
            # does not loop on it forever.
            if managee is ptype:
                continue
            self.walkDownTheProtocolTree(managee)

    def visit(self, ptype):
        """Walk |ptype|'s AST with |vtype| temporarily set to |ptype|."""
        outer = self.vtype
        self.vtype = ptype
        ptype._ast.accept(self)
        self.vtype = outer

    def visitMessageDecl(self, md):
        """Befriend the protocol being visited if |md| mentions a |mytype| actor."""
        hits = [
            actor for actor in self.iterActorParams(md) if actor.protocol == self.mytype
        ]
        if hits:
            self.friends.add(self.vtype)

    def iterActorParams(self, md):
        """Yield the actor types found in |md|'s in-params, then its out-params."""
        for param in md.inParams:
            yield from ipdl.type.iteractortypes(param.type)
        for param in md.outParams:
            yield from ipdl.type.iteractortypes(param.type)
class _GenerateProtocolActorCode(ipdl.ast.Visitor):
def __init__(self, myside):
    """Prepare a generator for the |myside| ("parent" or "child") actor."""
    self.side = myside
    # Capitalized side ("Parent"/"Child"), as used in generated actor names.
    self.prettyside = myside.title()

    # Populated later by lower() / the visit* methods.
    self.clsname = self.protocol = None
    self.hdrfile = self.cppfile = None
    self.ns = self.cls = None

    # Accumulators filled while visiting the translation unit.
    self.protocolCxxIncludes = []
    self.actorForwardDecls = []
    self.usingDecls = []
    self.externalIncludes = set()
    self.nonForwardDeclaredHeaders = set()

    # Typedefs every actor class gets; visit* methods add more.
    self.typedefSet = {
        Typedef(Type("mozilla::ipc::ActorHandle"), "ActorHandle"),
        Typedef(Type("base::ProcessId"), "ProcessId"),
        Typedef(Type("mozilla::ipc::ProtocolId"), "ProtocolId"),
        Typedef(Type("mozilla::ipc::Endpoint"), "Endpoint", ["FooSide"]),
        Typedef(
            Type("mozilla::ipc::ManagedEndpoint"),
            "ManagedEndpoint",
            ["FooSide"],
        ),
        Typedef(Type("mozilla::UniquePtr"), "UniquePtr", ["T"]),
        Typedef(
            Type("mozilla::ipc::ResponseRejectReason"),
            "ResponseRejectReason",
        ),
    }
def lower(self, tu, clsname, cxxHeaderFile, cxxFile):
    """Generate actor class |clsname| for |tu| into the given header/.cpp Files."""
    self.clsname, self.hdrfile, self.cppfile = clsname, cxxHeaderFile, cxxFile
    tu.accept(self)
def standardTypedefs(self):
    """Return freshly built typedefs for common IPC types used by actor classes."""
    aliases = (
        ("mozilla::ipc::IProtocol", "IProtocol"),
        ("IPC::Message", "Message"),
        ("base::ProcessHandle", "ProcessHandle"),
        ("mozilla::ipc::MessageChannel", "MessageChannel"),
        ("mozilla::ipc::SharedMemory", "SharedMemory"),
    )
    return [Typedef(Type(qualified), alias) for qualified, alias in aliases]
def visitTranslationUnit(self, tu):
    """Fill self.hdrfile and self.cppfile with the actor class generated for |tu|.

    Includes, using-statements and struct/union declarations are visited
    before the protocol. That way the include, typedef and forward-declaration
    collections are complete when the class and the files are assembled.
    """
    self.protocol = tu.protocol
    hf = self.hdrfile
    cf = self.cppfile

    # Header prologue: disclaimer, include guard, shared protocol header.
    guardStart = _includeGuardStart(hf)
    protocolHeader = CppDirective(
        "include", '"' + _protocolHeaderName(tu.protocol) + '.h"'
    )
    hf.addthings([_DISCLAIMER] + guardStart + [Whitespace.NL, protocolHeader])

    def acceptAll(nodes):
        # Dispatch every node in |nodes| to the matching visit* method.
        for node in nodes:
            node.accept(self)

    acceptAll(tu.includes)
    acceptAll(tu.cxxIncludes)
    acceptAll(tu.builtinUsing)
    acceptAll(tu.using)
    acceptAll(tu.structsAndUnions)

    # this generates the actor's full impl in self.cls
    tu.protocol.accept(self)
    clsdecl, clsdefn = _splitClassDeclDefn(self.cls)

    # XXX damn C++ ... return types in the method defn aren't in class
    # scope, so out-of-line definitions must spell |Result| as Cls::Result.
    for stmt in clsdefn.stmts:
        if not isinstance(stmt, MethodDefn):
            continue
        retType = stmt.decl.ret
        if retType and retType.name == "Result":
            retType.name = clsdecl.name + "::" + retType.name

    def setToIncludes(names):
        # One #include directive per entry, sorted for deterministic output.
        return [CppDirective("include", '"%s"' % name) for name in sorted(names)]

    def makeNamespace(p, file):
        # Open p's namespaces in |file| and return the innermost one to add
        # to; with no namespaces, |file| itself is returned.
        if not p.namespaces:
            return file
        innermost = Namespace(p.namespaces[-1].name)
        file.addthing(_putInNamespaces(innermost, p.namespaces[:-1]))
        return innermost

    if self.nonForwardDeclaredHeaders:
        banner = [
            Whitespace("// Headers for things that cannot be forward declared"),
            Whitespace.NL,
        ]
        self.hdrfile.addthings(
            banner + setToIncludes(self.nonForwardDeclaredHeaders) + [Whitespace.NL]
        )

    self.hdrfile.addthings(self.actorForwardDecls)
    self.hdrfile.addthings(self.usingDecls)

    padding = [Whitespace.NL, Whitespace.NL]
    hdrns = makeNamespace(self.protocol, self.hdrfile)
    hdrns.addstmts(padding + [clsdecl] + padding)

    actortype = ActorType(tu.protocol.decl.type)
    traitsdecl, traitsdefn = _ParamTraits.actorPickling(actortype, self.side)

    self.hdrfile.addthings([traitsdecl, Whitespace.NL] + _includeGuardEnd(hf))

    # If the implementation type is not overridden, add an implicit import
    # for the default implementation header file (e.g. "mozilla/dom/FooParent.h"
    # for the parent side of mozilla::dom::PFoo). Explicit implementation
    # types will specify their headers manually with `include`.
    if self.protocol.implAttribute(self.side) is None:
        assert self.protocol.name.startswith("P")
        namespacePath = "".join(n.name + "/" for n in self.protocol.namespaces)
        self.externalIncludes.add(
            namespacePath + self.protocol.name[1:] + self.side.capitalize() + ".h"
        )

    # make the .cpp file
    sideHeader = CppDirective(
        "include", '"' + _protocolHeaderName(self.protocol, self.side) + '.h"'
    )
    cf.addthings(
        [_DISCLAIMER, Whitespace.NL, sideHeader]
        + setToIncludes(self.externalIncludes)
    )

    protocolIncludes = [
        CppDirective("include", '"%s.h"' % inc) for inc in self.protocolCxxIncludes
    ]
    builtinIncludes = [
        CppDirective("include", '"%s"' % filename)
        for filename in ipdl.builtin.CppIncludes
    ]
    cf.addthings(
        [Whitespace.NL]
        + protocolIncludes
        + [Whitespace.NL]
        + builtinIncludes
        + [Whitespace.NL]
    )

    cppns = makeNamespace(self.protocol, cf)
    cppns.addstmts(padding + [clsdefn] + padding)

    cf.addthing(traitsdefn)
def visitUsingStmt(self, using):
    """Record the typedef and header needs of a `using` statement.

    A named declaration gets a typedef. If the statement names a header, the
    type is either forward declared, with the header included from the .cpp
    file, or, when it cannot be forward declared, its header is included
    directly in the generated header.
    """
    decl = using.decl
    if decl.fullname is not None:
        self.typedefSet.add(Typedef(Type(decl.fullname), decl.shortname))

    header = using.header
    if header is None:
        return

    if not using.canBeForwardDeclared():
        self.nonForwardDeclaredHeaders.add(header)
        return

    spec = using.type
    forwardDecl = _makeForwardDeclForQClass(
        spec.baseid,
        spec.quals,
        cls=using.isClass(),
        struct=using.isStruct(),
    )
    self.usingDecls.extend([forwardDecl, Whitespace.NL])
    self.externalIncludes.add(header)
def visitCxxInclude(self, inc):
    """Queue |inc|'s header for the #include list of the generated .cpp file."""
    self.externalIncludes |= {inc.file}
def visitInclude(self, inc):
    """Absorb what an `include` brings into scope.

    An included header contributes its cxx includes, using-statements and
    structs/unions. An included protocol, other than our own, contributes
    forward declarations of both of its actor sides, its protocol header
    and, when it has a fully-qualified name, typedefs for both actor names.
    """
    included = inc.tu
    if included.filetype == "header":
        # Including a header will declare any globals defined by "using"
        # statements into our scope. To serialize these, we also may need
        # cxx include statements, so visit them as well.
        for cxxinc in included.cxxIncludes:
            cxxinc.accept(self)
        for using in included.using:
            using.accept(self)
        for su in included.structsAndUnions:
            su.accept(self)
        return

    # Includes for protocols only include types explicitly exported by
    # those protocols.
    ip = included.protocol
    if ip == self.protocol:
        return

    self.actorForwardDecls.extend(
        [
            _makeForwardDeclForActor(ip.decl.type, self.side),
            _makeForwardDeclForActor(ip.decl.type, _otherSide(self.side)),
            Whitespace.NL,
        ]
    )
    self.protocolCxxIncludes.append(_protocolHeaderName(ip, self.side))

    if ip.decl.fullname is None:
        return
    # Typedef our side's actor name first, then the other side's.
    for side in (self.side, _otherSide(self.side)):
        prettySide = side.title()
        self.typedefSet.add(
            Typedef(
                Type(_actorName(ip.decl.fullname, prettySide)),
                _actorName(ip.decl.shortname, prettySide),
            )
        )
def visitStructDecl(self, sd):
    """Typedef a named struct's fully-qualified class name to its short name."""
    if sd.decl.fullname is None:
        return
    self.typedefSet.add(Typedef(Type(sd.fqClassName()), sd.name))
def visitUnionDecl(self, ud):
    """Typedef a named union's fully-qualified class name to its short name."""
    if ud.decl.fullname is None:
        return
    self.typedefSet.add(Typedef(Type(ud.fqClassName()), ud.name))
def visitProtocol(self, p):
self.hdrfile.addcode(
"""
#ifdef DEBUG
#include "prenv.h"
#endif // DEBUG
#include "mozilla/Tainting.h"
#include "mozilla/ipc/MessageChannel.h"
#include "mozilla/ipc/ProtocolUtils.h"
"""
)
self.protocol = p
ptype = p.decl.type
toplevel = p.decl.type.toplevel()
hasAsyncReturns =
False
for md
in p.messageDecls:
if md.hasAsyncReturns():
hasAsyncReturns =
True
break
if ptype.isToplevel():
inherits = [Inherit(Type(
"mozilla::ipc::IToplevelProtocol"))]
elif ptype.isRefcounted():
inherits = [Inherit(Type(
"mozilla::ipc::IRefCountedProtocol"))]
else:
inherits = [Inherit(Type(
"mozilla::ipc::IProtocol"))]
if ptype.isToplevel()
and self.side ==
"parent":
self.hdrfile.addthings(
[_makeForwardDeclForQClass(
"nsIFile", []), Whitespace.NL]
)
self.cls =
Class(self.clsname, inherits=inherits, abstract=
True)
self.cls.addstmt(Label.PRIVATE)
friends = _FindFriends().findFriends(ptype)
if ptype.isManaged():
friends.update(ptype.managers)
# |friend| managed actors so that they can call our Dealloc*()
friends.update(ptype.manages)
# don't friend ourself if we're a self-managed protocol
friends.discard(ptype)
for friend
in sorted(friends, key=
lambda f: f.fullname()):
self.actorForwardDecls.extend(
[_makeForwardDeclForActor(friend, self.prettyside), Whitespace.NL]
)
self.cls.addstmt(
FriendClassDecl(_actorName(friend.fullname(), self.prettyside))
)
self.cls.addstmt(Label.PUBLIC)
protocolid = ExprVar(
"kProtocolId")
self.cls.addcode(
"""
static constexpr ::mozilla::ipc::ProtocolId ${name} = ${value};
""",
name=protocolid,
value=_protocolId(ptype),
)
self.cls.addstmt(Label.PROTECTED)
for typedef
in sorted(self.typedefSet):
self.cls.addstmt(typedef)
self.cls.addstmt(Whitespace.NL)
if hasAsyncReturns:
self.cls.addstmt(Label.PUBLIC)
for md
in p.messageDecls:
if self.sendsMessage(md)
and md.hasAsyncReturns():
self.cls.addstmt(
Typedef(_makePromise(md.returns, self.side), md.promiseName())
)
if self.receivesMessage(md)
and md.hasAsyncReturns():
self.cls.addstmt(
Typedef(_makeResolver(md.returns, self.side), md.resolverName())
)
self.cls.addstmt(Whitespace.NL)
self.cls.addstmt(Label.PROTECTED)
# interface methods that the concrete subclass has to impl
for md
in p.messageDecls:
isctor, isdtor = md.decl.type.isCtor(), md.decl.type.isDtor()
if self.receivesMessage(md):
# generate Recv/Answer* interface
implicit =
not isdtor
returnsems =
"resolver" if md.decl.type.isAsync()
else "out"
recvDecl = MethodDecl(
md.recvMethod(),
params=md.makeCxxParams(
paramsems=
"move",
returnsems=returnsems,
side=self.side,
implicit=implicit,
direction=
"recv",
),
ret=Type(
"mozilla::ipc::IPCResult"),
methodspec=MethodSpec.VIRTUAL,
)
# These method implementations cause problems when trying to
# override them with different types in a direct call class.
#
# For the `isdtor` case there's a simple solution: it doesn't
# make much sense to specify arguments and then completely
# ignore them, and the no-arg case isn't a problem for
# overriding.
if isctor
or (isdtor
and not md.inParams):
defaultRecv = MethodDefn(recvDecl)
defaultRecv.addcode(
"return IPC_OK();\n")
self.cls.addstmt(defaultRecv)
elif self.protocol.implAttribute(self.side) ==
"virtual":
# If we're using virtual calls, we need the methods to be
# declared on the base class.
recvDecl.methodspec = MethodSpec.PURE
self.cls.addstmt(StmtDecl(recvDecl))
# If we're using virtual calls, we need the methods to be declared on
# the base class.
if self.protocol.implAttribute(self.side) ==
"virtual":
for md
in p.messageDecls:
managed = md.decl.type.constructedType()
if not ptype.isManagerOf(managed)
or md.decl.type.isDtor():
continue
# add the Alloc interface for managed actors
actortype = md.actorDecl().bareType(self.side)
if managed.isRefcounted():
if not self.receivesMessage(md):
continue
actortype.ptr =
False
actortype = _alreadyaddrefed(actortype)
self.cls.addstmt(
StmtDecl(
MethodDecl(
_allocMethod(managed, self.side),
params=md.makeCxxParams(
side=self.side, implicit=
False, direction=
"recv"
),
ret=actortype,
methodspec=MethodSpec.PURE,
)
)
)
# add the Dealloc interface for all managed non-refcounted actors,
# even without ctors. This is useful for protocols which use
# ManagedEndpoint for construction.
for managed
in ptype.manages:
if managed.isRefcounted():
continue
self.cls.addstmt(
StmtDecl(
MethodDecl(
_deallocMethod(managed, self.side),
params=[
Decl(p.managedCxxType(managed, self.side),
"aActor")
],
ret=Type.BOOL,
methodspec=MethodSpec.PURE,
)
)
)
if ptype.isToplevel():
# void ProcessingError(code); default to no-op
processingerror = MethodDefn(
MethodDecl(
p.processingErrorVar().name,
params=[
Param(_Result.Type(),
"aCode"),
Param(Type(
"char", const=
True, ptr=
True),
"aReason"),
],
methodspec=MethodSpec.OVERRIDE,
)
)
# bool ShouldContinueFromReplyTimeout(); default to |true|
shouldcontinue = MethodDefn(
MethodDecl(
p.shouldContinueFromTimeoutVar().name,
ret=Type.BOOL,
methodspec=MethodSpec.OVERRIDE,
)
)
shouldcontinue.addcode(
"return true;\n")
self.cls.addstmts(
[
processingerror,
shouldcontinue,
Whitespace.NL,
]
)
self.cls.addstmts(([Label.PUBLIC] + self.standardTypedefs() + [Whitespace.NL]))
self.cls.addstmt(Label.PUBLIC)
# Actor()
ctor = ConstructorDefn(ConstructorDecl(self.clsname))
side = ExprVar(
"mozilla::ipc::" + self.side.title() +
"Side")
if ptype.isToplevel():
name = ExprLiteral.String(_actorName(p.name, self.side))
ctor.memberinits = [
ExprMemberInit(
ExprVar(
"mozilla::ipc::IToplevelProtocol"),
[name, protocolid, side],
)
]
else:
baseCtor = (
ExprVar(
"mozilla::ipc::IRefCountedProtocol")
if ptype.isRefcounted()
else ExprVar(
"mozilla::ipc::IProtocol")
)
ctor.memberinits = [ExprMemberInit(baseCtor, [protocolid, side])]
ctor.addcode(
"MOZ_COUNT_CTOR(${clsname});\n", clsname=self.clsname)
self.cls.addstmts([ctor, Whitespace.NL])
# ~Actor()
dtor = DestructorDefn(
DestructorDecl(self.clsname, methodspec=MethodSpec.VIRTUAL)
)
dtor.addcode(
"MOZ_COUNT_DTOR(${clsname});\n", clsname=self.clsname)
self.cls.addstmts([dtor, Whitespace.NL])
# ActorAlloc() and ActorDealloc()
actoralloc = MethodDefn(MethodDecl(
"ActorAlloc", methodspec=MethodSpec.FINAL))
actordealloc = MethodDefn(
MethodDecl(
"ActorDealloc", methodspec=MethodSpec.FINAL)
)
# Assert process type in ActorAlloc
procattr = p.procAttribute(self.side)
if procattr
not in (
"any",
None):
if procattr ==
"anychild":
procattr_assertion =
"!XRE_IsParentProcess()"
elif procattr ==
"anydom":
procattr_assertion =
"XRE_IsParentProcess() || XRE_IsContentProcess()"
elif procattr ==
"compositor":
procattr_assertion =
"XRE_IsParentProcess() || XRE_IsGPUProcess()"
else:
procattr_assertion =
"XRE_Is%sProcess()" % procattr
actoralloc.addcode(
"MOZ_RELEASE_ASSERT(${assertion}, ${message});\n",
assertion=procattr_assertion,
message=ExprLiteral.String(
"Invalid process for `%s'" % self.clsname),
)
if ptype.isRefcounted():
# Perform AddRef/Release in ActorAlloc/ActorDealloc if refcounted.
actoralloc.addcode(
"AddRef();\n")
actordealloc.addcode(
"Release();\n")
elif not ptype.isToplevel():
# If we're a managed actor with [ManualDealloc], use DeallocManagee
# to invoke the relevant Dealloc method.
actordealloc.addcode(
"""
if (Manager()) {
Manager()->DeallocManagee(${protocolId}, this);
}
""",
protocolId=protocolid,
)
self.cls.addstmts([Label.PROTECTED, actoralloc, actordealloc])
self.cls.addstmt(Label.PUBLIC)
if ptype.hasOtherPid():
otherpidmeth = MethodDefn(
MethodDecl(
"OtherPid", ret=Type(
"::base::ProcessId"), const=
True)
)
otherpidmeth.addcode(
"""
::base::ProcessId pid =
::mozilla::ipc::IProtocol::ToplevelProtocol()->OtherPidMaybeInvalid();
MOZ_RELEASE_ASSERT(pid != ::base::kInvalidProcessId);
return pid;
"""
)
otherchildidmeth = MethodDefn(
MethodDecl(
"OtherChildID", ret=Type(
"::GeckoChildID"), const=
True)
)
otherchildidmeth.addcode(
"""
::GeckoChildID childID =
::mozilla::ipc::IProtocol::ToplevelProtocol()->OtherChildIDMaybeInvalid();
MOZ_RELEASE_ASSERT(childID != -1);
return childID;
"""
)
otherendpointprocinfometh = MethodDefn(
MethodDecl(
"OtherEndpointProcInfo",
ret=Type(
"::mozilla::ipc::EndpointProcInfo"),
const=
True,
)
)
otherendpointprocinfometh.addcode(
"""
return ::mozilla::ipc::EndpointProcInfo{OtherPid(), OtherChildID()};
"""
)
self.cls.addstmts(
[
otherpidmeth,
otherchildidmeth,
otherendpointprocinfometh,
Whitespace.NL,
]
)
if not ptype.isToplevel():
if 1 == len(p.managers):
# manager() const
managertype = p.managerActorType(self.side, ptr=
True)
managermeth = MethodDefn(
MethodDecl(
"Manager", ret=managertype, const=
True)
)
managermeth.addcode(
"""
return static_cast<${type}>(IProtocol::Manager());
""",
type=managertype,
)
self.cls.addstmts([managermeth, Whitespace.NL])
# Managed[T](Array& inout) const
# const Array<T>& Managed() const
for managed
in ptype.manages:
container = p.managedVar(managed, self.side)
meth = MethodDefn(
MethodDecl(
p.managedMethod(managed, self.side).name,
params=[
Decl(
_cxxArrayType(
p.managedCxxType(managed, self.side), ref=
True
),
"aArr",
)
],
const=
True,
)
)
meth.addcode(
"${container}.ToArray(aArr);\n", container=container)
refmeth = MethodDefn(
MethodDecl(
p.managedMethod(managed, self.side).name,
params=[],
ret=p.managedVarType(managed, self.side, const=
True, ref=
True),
const=
True,
)
)
refmeth.addcode(
"return ${container};\n", container=container)
self.cls.addstmts([meth, refmeth, Whitespace.NL])
managedprotocolids = MethodDefn(
MethodDecl(
"ManagedProtocolIds",
ret=_cxxSpanType(Type(
"mozilla::ipc::ProtocolId")),
methodspec=MethodSpec.FINAL,
const=
True,
)
)
if ptype.manages:
managedprotocolids.addcode(
"""
static constexpr ::mozilla::ipc::ProtocolId sIds[] = {
$,{values}
};
return sIds;
""",
values=[_protocolId(managed)
for managed
in ptype.manages],
)
else:
managedprotocolids.addcode(
"return {};\n")
self.cls.addstmts([managedprotocolids, Whitespace.NL])
# UntypedManagedContainer& GetManagedActors(ProtocolId)
managedactors = MethodDefn(
MethodDecl(
"GetManagedActors",
ret=Type(
"UntypedManagedContainer", ptr=
True),
methodspec=MethodSpec.FINAL,
params=[Decl(Type(
"mozilla::ipc::ProtocolId"),
"aProtocol")],
)
)
selectmanagedactors = StmtSwitch(ExprVar(
"aProtocol"))
for managed
in ptype.manages:
selectmanagedactors.addcase(
CaseLabel(_protocolId(managed).name),
StmtReturn(ExprAddrOf(p.managedVar(managed, self.side))),
)
selectmanagedactors.addcase(DefaultLabel(), StmtReturn(ExprLiteral.NULL))
managedactors.addstmt(selectmanagedactors)
self.cls.addstmts([managedactors, Whitespace.NL])
# void RejectPendingResponses()
if hasAsyncReturns:
rejectpendingresponses = MethodDefn(
MethodDecl(
"RejectPendingResponses",
methodspec=MethodSpec.FINAL,
params=[Decl(_ResponseRejectReason.Type(),
"aReason")],
)
)
rejectpendingresponses.addcode(
"""
mAsyncCallbacks.RejectPendingResponses(aReason);
""",
)
self.cls.addstmts([rejectpendingresponses, Whitespace.NL])
# OpenPEndpoint(...)/BindPEndpoint(...)
for managed
in ptype.manages:
self.genManagedEndpoint(managed)
# OnMessageReceived()
# save these away for use in message handler case stmts
msgvar = ExprVar(
"msg__")
self.msgvar = msgvar
replyvar = ExprVar(
"reply__")
self.replyvar = replyvar
var = ExprVar(
"v__")
self.var = var
# for ctor recv cases, we can't read the actor ID into a PFoo*
# because it doesn't exist on this side yet. Use a "special"
# actor handle instead
handlevar = ExprVar(
"handle__")
self.handlevar = handlevar
msgtype = ExprCode(
"msg__.type()")
self.asyncSwitch = StmtSwitch(msgtype)
self.syncSwitch =
None
if toplevel.isSync():
self.syncSwitch = StmtSwitch(msgtype)
# Add a handler for the MANAGED_ENDPOINT_BOUND and
# MANAGED_ENDPOINT_DROPPED message types for managed actors.
if not ptype.isToplevel():
clearawaitingmanagedendpointbind =
"""
if (!mAwaitingManagedEndpointBind) {
NS_WARNING(
"Unexpected managed endpoint lifecycle message after actor bound!");
return MsgNotAllowed;
}
mAwaitingManagedEndpointBind =
false;
"""
self.asyncSwitch.addcase(
CaseLabel(
"MANAGED_ENDPOINT_BOUND_MESSAGE_TYPE"),
StmtBlock(
[
StmtCode(clearawaitingmanagedendpointbind),
StmtReturn(_Result.Processed),
]
),
)
self.asyncSwitch.addcase(
CaseLabel(
"MANAGED_ENDPOINT_DROPPED_MESSAGE_TYPE"),
StmtBlock(
[
StmtCode(clearawaitingmanagedendpointbind),
*self.destroyActor(
None,
ExprVar.THIS,
why=_DestroyReason.ManagedEndpointDropped,
),
StmtReturn(_Result.Processed),
]
),
)
# implement Send*() methods and add dispatcher cases to
# message switch()es
for md
in p.messageDecls:
self.visitMessageDecl(md)
# add default cases
default = StmtCode(
"""
return MsgNotKnown;
"""
)
self.asyncSwitch.addcase(DefaultLabel(), default)
if toplevel.isSync():
self.syncSwitch.addcase(DefaultLabel(), default)
self.cls.addstmts(self.implementManagerIface())
def makeHandlerMethod(name, switch, hasReply, dispatches=
False):
params = [Decl(Type(
"Message", const=
True, ref=
True), msgvar.name)]
if hasReply:
params.append(Decl(Type(
"UniquePtr<Message>", ref=
True), replyvar.name))
method = MethodDefn(
MethodDecl(
name,
methodspec=MethodSpec.OVERRIDE,
params=params,
ret=_Result.Type(),
)
)
if not switch:
method.addcode(
"""
MOZ_ASSERT_UNREACHABLE(
"message protocol not supported");
return MsgNotKnown;
"""
)
return method
if dispatches:
if hasReply:
ondeadactor = [StmtReturn(_Result.Dropped)]
else:
ondeadactor = [
self.logMessage(
None, ExprAddrOf(msgvar),
"Ignored message for dead actor"
),
StmtReturn(_Result.Processed),
]
method.addcode(
"""
int32_t route__ = ${msgvar}.routing_id();
if (MSG_ROUTING_CONTROL != route__) {
IProtocol* routed__ = Lookup(route__);
if (!routed__ || !routed__->GetLifecycleProxy()) {
$*{ondeadactor}
}
RefPtr<mozilla::ipc::ActorLifecycleProxy> proxy__ =
routed__->GetLifecycleProxy();
return proxy__->Get()->${name}($,{args});
}
""",
msgvar=msgvar,
ondeadactor=ondeadactor,
name=name,
args=[p.name
for p
in params],
)
# bug 509581: don't generate the switch stmt if there
# is only the default case; MSVC doesn't like that
if switch.nr_cases > 1:
method.addstmt(switch)
else:
method.addstmt(StmtReturn(_Result.NotKnown))
return method
dispatches = ptype.isToplevel()
and ptype.isManager()
self.cls.addstmts(
[
makeHandlerMethod(
"OnMessageReceived",
self.asyncSwitch,
hasReply=
False,
dispatches=dispatches,
),
Whitespace.NL,
]
)
self.cls.addstmts(
[
makeHandlerMethod(
"OnMessageReceived",
self.syncSwitch,
hasReply=
True,
dispatches=dispatches,
),
Whitespace.NL,
]
)
# private methods
self.cls.addstmt(Label.PRIVATE)
if hasAsyncReturns:
self.cls.addstmts(
[
StmtDecl(
Decl(
Type(
"mozilla::ipc::IPDLAsyncReturnsCallbacks"),
"mAsyncCallbacks",
)
),
Whitespace.NL,
]
)
if not ptype.isToplevel():
self.cls.addstmts(
[
StmtDecl(
Decl(Type.BOOL,
"mAwaitingManagedEndpointBind"),
init=ExprLiteral.
FALSE,
),
Whitespace.NL,
]
)
for managed
in ptype.manages:
self.cls.addstmts(
[
StmtDecl(
Decl(
p.managedVarType(managed, self.side),
p.managedVar(managed, self.side).name,
)
)
]
)
def genManagedEndpoint(self, managed):
    """Emit the Open<P>Endpoint / Bind<P>Endpoint helpers for |managed|.

    Open<P>Endpoint(aActor) binds |aActor| under this manager, marks it as
    awaiting the remote bind, and returns a ManagedEndpoint for the other
    side (a default-constructed endpoint if binding fails).
    Bind<P>Endpoint(aEndpoint, aActor) forwards to aEndpoint.Bind() with
    ``this`` as the manager and returns its bool result.
    """
    hereEp = "ManagedEndpoint<%s>" % _actorName(managed.name(), self.side)
    thereEp = "ManagedEndpoint<%s>" % _actorName(
        managed.name(), _otherSide(self.side)
    )
    actor = _HybridDecl(ipdl.type.ActorType(managed), "aActor")

    # ManagedEndpoint<PThere> OpenPEndpoint(PHere* aActor)
    openmeth = MethodDefn(
        MethodDecl(
            "Open%sEndpoint" % managed.name(),
            params=[
                Decl(self.protocol.managedCxxType(managed, self.side), actor.name)
            ],
            ret=Type(thereEp),
        )
    )
    # Every line of prose in the C++ comment below must start with `//`;
    # otherwise the text is emitted as (invalid) C++ code.
    openmeth.addcode(
        """
$*{bind}
// Mark our actor as awaiting the other side to be bound. This will
// be cleared when a `MANAGED_ENDPOINT_{DROPPED,BOUND}` message is
// received.
aActor->mAwaitingManagedEndpointBind =
true;
return ${thereEp}(mozilla::ipc::PrivateIPDLInterface(), aActor);
""",
        # On bind failure the generated code returns an empty endpoint.
        bind=self.bindManagedActor(actor, errfn=ExprCall(ExprVar(thereEp))),
        thereEp=thereEp,
    )

    # void BindPEndpoint(ManagedEndpoint<PHere>&& aEndpoint, PHere* aActor)
    bindmeth = MethodDefn(
        MethodDecl(
            "Bind%sEndpoint" % managed.name(),
            params=[
                Decl(Type(hereEp), "aEndpoint"),
                Decl(self.protocol.managedCxxType(managed, self.side), actor.name),
            ],
            ret=Type.BOOL,
        )
    )
    bindmeth.addcode(
        """
return aEndpoint.Bind(mozilla::ipc::PrivateIPDLInterface(), aActor, this);
"""
    )
    self.cls.addstmts([openmeth, bindmeth, Whitespace.NL])
def implementManagerIface(self):
    """Return the class members implementing the manager interface.

    Side effect: for toplevel protocols, registers the SHMEM_CREATED and
    SHMEM_DESTROYED handlers on ``self.asyncSwitch``.

    Returns ``[DeallocManagee override, Whitespace.NL]``. The override
    switches on the protocol ID and calls this side's Dealloc method for
    each non-refcounted managed type; any other ID is a FatalError.
    """
    p = self.protocol
    protocolbase = Type("IProtocol", ptr=True)

    if p.decl.type.isToplevel():
        # FIXME: This used to be declared conditionally based on whether
        # shmem appeared somewhere in the protocol hierarchy, however that
        # caused issues due to Shmem instances hidden within custom C++
        # types.
        self.asyncSwitch.addcase(
            CaseLabel("SHMEM_CREATED_MESSAGE_TYPE"),
            self.genShmemCreatedHandler(),
        )
        self.asyncSwitch.addcase(
            CaseLabel("SHMEM_DESTROYED_MESSAGE_TYPE"),
            self.genShmemDestroyedHandler(),
        )

    # The `DeallocManagee` method is called for managed actors to trigger
    # deallocation when ActorLifecycleProxy is freed.
    pvar = ExprVar("aProtocolId")
    listenervar = ExprVar("aListener")
    deallocmanagee = MethodDefn(
        MethodDecl(
            p.deallocManageeMethod().name,
            params=[
                Decl(_protocolIdType(), pvar.name),
                Decl(protocolbase, listenervar.name),
            ],
            methodspec=MethodSpec.OVERRIDE,
        )
    )

    # Reference counted actor types don't have corresponding `Dealloc`
    # methods, as they are deallocated by releasing the IPDL-held
    # reference, so only non-refcounted managees get a switch case.
    deallocatable = [
        managee.decl.type
        for managee in p.managesStmts
        if not managee.decl.type.isRefcounted()
    ]

    unreached = """
FatalError(
"unreached");
return;
"""
    if not deallocatable:
        # Emit the fatal error directly instead of a switch containing only
        # a default label (bug 509581: MSVC doesn't like that).
        deallocmanagee.addcode(unreached)
        return [deallocmanagee, Whitespace.NL]

    switchontype = StmtSwitch(pvar)
    for manageeipdltype in deallocatable:
        case = StmtCode(
            """
${concrete}->${dealloc}(static_cast<${type}>(aListener));
return;
""",
            concrete=self.concreteThis(),
            dealloc=_deallocMethod(manageeipdltype, self.side),
            type=_cxxBareType(ipdl.type.ActorType(manageeipdltype), self.side),
        )
        switchontype.addcase(CaseLabel(_protocolId(manageeipdltype).name), case)
    switchontype.addcase(DefaultLabel(), StmtCode(unreached))
    deallocmanagee.addstmt(switchontype)
    return [deallocmanagee, Whitespace.NL]
def genShmemCreatedHandler(self):
    """Return the case body for SHMEM_CREATED_MESSAGE_TYPE (toplevel only).

    The generated C++ passes the received message (``self.msgvar``) to
    ShmemCreated() and yields MsgPayloadError on failure, MsgProcessed
    otherwise.
    """
    assert self.protocol.decl.type.isToplevel()
    handler = """
{
if (!ShmemCreated(${msgvar})) {
return MsgPayloadError;
}
return MsgProcessed;
}
"""
    return StmtCode(handler, msgvar=self.msgvar)
def genShmemDestroyedHandler(self):
    """Return the case body for SHMEM_DESTROYED_MESSAGE_TYPE (toplevel only).

    The generated C++ passes the received message (``self.msgvar``) to
    ShmemDestroyed() and yields MsgPayloadError on failure, MsgProcessed
    otherwise.
    """
    assert self.protocol.decl.type.isToplevel()
    handler = """
{
if (!ShmemDestroyed(${msgvar})) {
return MsgPayloadError;
}
return MsgProcessed;
}
"""
    return StmtCode(handler, msgvar=self.msgvar)
# -------------------------------------------------------------------------
# The next few functions are the crux of the IPDL code generator.
# They generate code for all the nasty work of message
# serialization/deserialization and dispatching handlers for
# received messages.
##
def concreteThis(self):
    """Return an expression for ``this`` as the concrete implementation type.

    - implementation attribute ``"virtual"``: plain ``this``;
    - no attribute: ``static_cast<Foo{Parent,Child}*>(this)``, derived from
      the protocol name with its leading "P" removed;
    - string-literal attribute: ``static_cast<Name*>(this)``.
    """
    impl = self.protocol.implAttribute(self.side)
    if impl == "virtual":
        return ExprVar.THIS

    if impl is not None:
        assert isinstance(impl, ipdl.ast.StringLiteral)
        concrete = impl.value
    else:
        assert self.protocol.name.startswith("P")
        concrete = self.protocol.name[1:] + self.side.capitalize()

    return ExprCode("static_cast<${className}*>(this)", className=concrete)
def thisCall(self, function, args):
    """Return ``<concrete this>->function(args...)`` as a call expression."""
    receiver = self.concreteThis()
    callee = ExprSelect(receiver, "->", function)
    return ExprCall(callee, args=args)
def visitMessageDecl(self, md):
    # Generate this side's code for message |md|:
    #  - when this side sends it: the Send*() method(s) (plus, for
    #    non-refcounted ctors, a helper ctor overload) and any reply case
    #    produced by the send path;
    #  - when this side receives it: the dispatcher case.
    # Cases go to self.asyncSwitch or self.syncSwitch via addRecvCase.
    isctor = md.decl.type.isCtor()
    isdtor = md.decl.type.isDtor()
    decltype = md.decl.type
    sendmethod = None
    movesendmethod = None
    promisesendmethod = None
    recvlbl, recvcase = None, None

    def addRecvCase(lbl, case):
        # Route the case to the switch matching the message's sync-ness.
        if decltype.isAsync():
            self.asyncSwitch.addcase(lbl, case)
        elif decltype.isSync():
            self.syncSwitch.addcase(lbl, case)
        else:
            assert 0

    if self.sendsMessage(md):
        isasync = decltype.isAsync()
        # NOTE: Don't generate helper ctors for refcounted types.
        #
        # Safety concerns around providing your own actor to a ctor (namely
        # that the return value won't be checked, and the argument will be
        # `delete`-ed) are less critical with refcounted actors, due to the
        # actor being held alive by the callsite.
        #
        # This allows refcounted actors to not implement crashing AllocPFoo
        # methods on the sending side.
        if isctor and not md.decl.type.constructedType().isRefcounted():
            self.cls.addstmts([self.genHelperCtor(md), Whitespace.NL])
        if isctor and isasync:
            sendmethod, (recvlbl, recvcase) = self.genAsyncCtor(md)
        elif isctor:
            sendmethod = self.genBlockingCtorMethod(md)
        elif isdtor:
            # Only async dtors are supported here.
            assert isasync
            sendmethod, (recvlbl, recvcase) = self.genDtor(md)
        elif isasync:
            (
                sendmethod,
                movesendmethod,
                promisesendmethod,
                (recvlbl, recvcase),
            ) = self.genAsyncSendMethod(md)
        else:
            sendmethod, movesendmethod = self.genBlockingSendMethod(md)

    # XXX figure out what to do here
    # (A dtor whose constructed type is toplevel gets no Send method.)
    if isdtor and md.decl.type.constructedType().isToplevel():
        sendmethod = None

    if sendmethod is not None:
        self.cls.addstmts([sendmethod, Whitespace.NL])
    if movesendmethod is not None:
        self.cls.addstmts([movesendmethod, Whitespace.NL])
    if promisesendmethod is not None:
        self.cls.addstmts([promisesendmethod, Whitespace.NL])
    if recvcase is not None:
        # Reply case produced by the send path (e.g. async reply handling).
        addRecvCase(recvlbl, recvcase)
        recvlbl, recvcase = None, None

    if self.receivesMessage(md):
        if isctor:
            recvlbl, recvcase = self.genCtorRecvCase(md)
        elif isdtor:
            recvlbl, recvcase = self.genDtorRecvCase(md)
        else:
            recvlbl, recvcase = self.genRecvCase(md)

        # XXX figure out what to do here
        # (Likewise, no receive case for a toplevel dtor.)
        if isdtor and md.decl.type.constructedType().isToplevel():
            return

        addRecvCase(recvlbl, recvcase)
def genAsyncCtor(self, md):
    """Build the async constructor Send method for |md| and its reply case.

    The generated C++ binds the new actor under this manager, builds and
    sends the constructor message, and returns the actor. If the send
    fails, it warns, disconnects the actor (FailedConstructor) and returns
    nullptr.

    Returns ``(method, (label, case))``. The reply case only returns
    ``_Result.Processed``.
    """
    actor = md.actorDecl()
    method = MethodDefn(self.makeSendMethodDecl(md))

    msgvar, stmts = self.makeMessage(md, errfnSendCtor)
    sendok, sendstmts = self.sendAsync(md, msgvar)

    # Every prose line in the C++ comments below starts with `//`; a bare
    # wrapped line would be emitted as invalid C++.
    method.addcode(
        """
$*{bind}
// Build our constructor message.
$*{stmts}
// Notify the other side about the newly created actor. This can
// fail if our manager has already been destroyed.
//
// NOTE: If the send call fails due to toplevel channel teardown,
// the `IProtocol::ChannelSend` wrapper absorbs the error for us,
// so we don't tear down actors unexpectedly.
$*{sendstmts}
// Warn, destroy the actor, and return null if the message failed to
// send. Otherwise, return the successfully created actor reference.
if (!${sendok}) {
NS_WARNING(
"Error sending ${actorname} constructor");
$*{destroy}
return nullptr;
}
return ${actor};
""",
        bind=self.bindManagedActor(actor),
        stmts=stmts,
        sendstmts=sendstmts,
        sendok=sendok,
        destroy=self.destroyActor(
            md, actor.var(), why=_DestroyReason.FailedConstructor
        ),
        actor=actor.var(),
        actorname=actor.ipdltype.protocol.name() + self.side.capitalize(),
    )

    lbl = CaseLabel(md.pqReplyId())
    case = StmtBlock()
    case.addstmt(StmtReturn(_Result.Processed))
    # TODO not really sure what to do with async ctor "replies" yet.
    # destroy actor if there was an error? tricky ...
    return method, (lbl, case)
def genBlockingCtorMethod(self, md):
    """Build the synchronous constructor Send method for |md|.

    The generated C++ binds the actor, builds the constructor message and
    sends it synchronously. On send failure it warns, disconnects the actor
    (FailedConstructor) and returns nullptr. Otherwise it deserializes the
    reply and returns the actor.
    """
    actor = md.actorDecl()
    method = MethodDefn(self.makeSendMethodDecl(md))

    msgvar, stmts = self.makeMessage(md, errfnSendCtor)

    replyvar = self.replyvar
    sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar)
    replystmts = self.deserializeReply(
        md,
        replyvar,
        self.side,
        errfnSendCtor,
        errfnSentinel(ExprLiteral.NULL),
    )

    # Every prose line in the C++ comments below starts with `//`; a bare
    # wrapped line would be emitted as invalid C++.
    method.addcode(
        """
$*{bind}
// Build our constructor message.
$*{stmts}
// Synchronously send the constructor message to the other side. If
// the send fails, e.g. due to the remote side shutting down, the
// actor will be destroyed and potentially freed.
UniquePtr<Message> ${replyvar};
$*{sendstmts}
if (!(${sendok})) {
// Warn, destroy the actor and return null if the message
// failed to send.
NS_WARNING(
"Error sending constructor");
$*{destroy}
return nullptr;
}
$*{replystmts}
return ${actor};
""",
        bind=self.bindManagedActor(actor),
        stmts=stmts,
        replyvar=replyvar,
        sendstmts=sendstmts,
        sendok=sendok,
        destroy=self.destroyActor(
            md, actor.var(), why=_DestroyReason.FailedConstructor
        ),
        replystmts=replystmts,
        actor=actor.var(),
        actorname=actor.ipdltype.protocol.name() + self.side.capitalize(),
    )

    return method
def bindManagedActor(self, actordecl, errfn=ExprLiteral.NULL, idexpr=None):
    """Return statements binding |actordecl|'s actor under this manager.

    The generated C++ warns and returns |errfn| when the actor is null or
    when SetManagerAndRegister() fails. SetManagerAndRegister receives
    ``this`` and, if |idexpr| is given, that ID as a second argument.
    """
    protocol = actordecl.ipdltype.protocol
    extra = [] if idexpr is None else [idexpr]
    managerargs = [ExprVar.THIS] + extra

    binding = StmtCode(
        """
if (!${actor}) {
NS_WARNING(
"Cannot bind null ${actorname} actor");
return ${errfn};
}
if (!${actor}->SetManagerAndRegister($,{setManagerArgs})) {
NS_WARNING(
"Failed to bind ${actorname} actor");
return ${errfn};
}
""",
        actor=actordecl.var(),
        actorname=protocol.name() + self.side.capitalize(),
        errfn=errfn,
        setManagerArgs=managerargs,
        container=self.protocol.managedVar(protocol, self.side),
    )
    return [binding]
def genHelperCtor(self, md):
    """Build the constructor Send overload that allocates its own actor.

    The declaration is the regular Send declaration minus its first
    parameter (presumably the caller-supplied actor; confirm against
    makeCxxParams). The body allocates the actor through this side's Alloc
    method, then forwards all arguments to the full Send overload.
    """
    helperdecl = self.makeSendMethodDecl(md)
    helperdecl.params = helperdecl.params[1:]
    helper = MethodDefn(helperdecl)

    allocstmt = self.callAllocActor(md, retsems="out", side=self.side)
    forward = ExprCall(
        ExprVar(helperdecl.name), args=md.makeCxxArgs(paramsems="move")
    )
    helper.addstmts([allocstmt, StmtReturn(forward)])
    return helper
def genDtor(self, md):
    """Build the static async __delete__ Send method and its reply case.

    The generated method checks the actor with dtorPrologue, sends the
    message from that actor, disconnects it via dtorEpilogue and returns the
    send result. The reply case only returns ``_Result.Processed``.
    """
    actorvar = ExprVar("actor")
    method = MethodDefn(self.makeDtorMethodDecl(md, actorvar))
    method.addstmt(self.dtorPrologue(actorvar))

    msgvar, buildstmts = self.makeMessage(md, errfnSendDtor, actorvar)
    sendok, sendstmts = self.sendAsync(md, msgvar, actorvar)

    body = list(buildstmts)
    body.extend(sendstmts)
    body.append(Whitespace.NL)
    body.extend(self.dtorEpilogue(md, actorvar))
    body.append(StmtReturn(sendok))
    method.addstmts(body)

    replylabel = CaseLabel(md.pqReplyId())
    replycase = StmtBlock()
    replycase.addstmt(StmtReturn(_Result.Processed))
    # TODO if the dtor is "inherently racy", keep the actor alive
    # until the other side acks
    return method, (replylabel, replycase)
def destroyActor(self, md, actorexpr, why=_DestroyReason.Deletion):
    """Return a one-element list calling ``ActorDisconnected(why)`` on |actorexpr|.

    |md| is accepted for call-site symmetry but not used.
    """
    disconnect = StmtCode(
        """
${actor}->ActorDisconnected(${why});
""",
        actor=actorexpr,
        why=why,
    )
    return [disconnect]
def dtorPrologue(self, actorexpr):
    """Return a guard that warns and returns false if |actorexpr| is null or
    cannot send."""
    guard = """
if (!${actor} || !${actor}->CanSend()) {
NS_WARNING(
"Attempt to __delete__ missing or closed actor");
return false;
}
"""
    return StmtCode(guard, actor=actorexpr)
def dtorEpilogue(self, md, actorexpr):
    """Return the teardown statements for |actorexpr| after a __delete__.

    Delegates to destroyActor using its default reason.
    """
    teardown = self.destroyActor(md, actorexpr)
    return teardown
def genRecvAsyncReplyCase(self, md):
    """Return ``(label, case)`` dispatching an async reply for |md|.

    The case logs the received reply, opens a profiler label and hands the
    message to ``mAsyncCallbacks.GotReply``.
    """
    label = CaseLabel(md.pqReplyId())
    body = StmtBlock()
    received = self.logMessage(
        md, ExprAddrOf(self.msgvar), "Received ", receiving=True
    )
    profiling = self.profilerLabel(md)
    body.addcode(
        """
${logMessage}
${profilerLabel}
return mAsyncCallbacks.GotReply(this, ${msgvar});
""",
        logMessage=received,
        profilerLabel=profiling,
        msgvar=self.msgvar,
    )
    return (label, body)
def genAsyncSendMethod(self, md):
    """Build the async Send method(s) for |md|.

    Returns ``(method, movemethod, promisemethod, (label, case))``.
    ``movemethod`` is always None. When |md| has returns, a promise-returning
    overload and the async-reply receive case are also produced; otherwise
    those slots are None. Both declarations become virtual when |md| carries
    the ``VirtualSendImpl`` attribute.
    """
    makevirtual = "VirtualSendImpl" in md.attributes

    def declare(**kwargs):
        # Shared declaration logic for the callback and promise overloads.
        decl = self.makeSendMethodDecl(md, **kwargs)
        if makevirtual:
            decl.methodspec = MethodSpec.VIRTUAL
        return decl

    method = MethodDefn(declare())
    msgvar, buildstmts = self.makeMessage(md, errfnSend)
    retvar, sendstmts = self.sendAsync(md, msgvar)
    method.addstmts(
        buildstmts + [Whitespace.NL] + sendstmts + [StmtReturn(retvar)]
    )

    if not md.returns:
        return method, None, None, (None, None)

    # Add the promise overload.
    promisemethod = MethodDefn(declare(promise=True))
    promisemethod.addstmts(self.sendAsyncWithPromise(md))
    replylabel, replycase = self.genRecvAsyncReplyCase(md)
    return method, None, promisemethod, (replylabel, replycase)
def genBlockingSendMethod(self, md):
    """Build the synchronous Send method for |md|.

    The generated method serializes the arguments, sends synchronously,
    returns false if the send fails, otherwise unpacks the reply into the
    outparams and returns true. Returns ``(method, None)``; the second slot
    (move overload) is always None.
    """
    method = MethodDefn(self.makeSendMethodDecl(md))
    msgvar, serstmts = self.makeMessage(md, errfnSend)
    replyvar = self.replyvar
    sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar)

    bail = StmtIf(ExprNot(sendok))
    bail.addifstmt(StmtReturn.FALSE)

    unpack = self.deserializeReply(
        md, replyvar, self.side, errfnSend, errfnSentinel()
    )

    body = list(serstmts)
    body += [Whitespace.NL, StmtDecl(Decl(Type("UniquePtr<Message>"), replyvar.name))]
    body += sendstmts
    body.append(bail)
    body += unpack
    body += [Whitespace.NL, StmtReturn.TRUE]
    method.addstmts(body)
    return method, None
def genCtorRecvCase(self, md):
    """Return ``(label, case)`` handling an incoming constructor message.

    The case deserializes the message (including the remote actor handle),
    declares locals for any returns, allocates the actor on this side,
    registers it under the handle's ID, runs the Recv handler, and appends
    the reply-building statements from makeReply.
    """
    incoming = self.deserializeMessage(
        md, self.side, errfnRecv, errfnSent=errfnSentinel(_Result.ValuError)
    )
    idvar, saveIdStmts = self.saveActorId(md)

    body = incoming + [
        StmtDecl(Decl(ret.bareType(self.side), ret.var().name), initargs=[])
        for ret in md.returns
    ]
    # Allocate the actor and register it under the foreign ID.
    body.append(self.callAllocActor(md, retsems="in", side=self.side))
    body += self.bindManagedActor(
        md.actorDecl(), errfn=_Result.ValuError, idexpr=_actorHId(self.handlevar)
    )
    body.append(Whitespace.NL)
    body += saveIdStmts
    body += self.invokeRecvHandler(md)
    body += self.makeReply(md, errfnRecv, idvar)
    body += [Whitespace.NL, StmtReturn(_Result.Processed)]

    case = StmtBlock()
    case.addstmts(body)
    return CaseLabel(md.pqMsgId()), case
def genDtorRecvCase(self, md):
    """Return ``(label, case)`` handling an incoming __delete__ message.

    The case deserializes the arguments, declares locals for any returns,
    runs the Recv handler, appends the reply-building statements from
    makeReply, and finally tears down ``this`` via dtorEpilogue.
    """
    incoming = self.deserializeMessage(
        md, self.side, errfnRecv, errfnSent=errfnSentinel(_Result.ValuError)
    )
    idvar, saveIdStmts = self.saveActorId(md)

    body = incoming + [
        StmtDecl(Decl(ret.bareType(self.side), ret.var().name), initargs=[])
        for ret in md.returns
    ]
    body += self.invokeRecvHandler(md)
    body.append(Whitespace.NL)
    body += saveIdStmts
    body += self.makeReply(md, errfnRecv, routingId=idvar)
    body.append(Whitespace.NL)
    body += self.dtorEpilogue(md, ExprVar.THIS)
    body += [Whitespace.NL, StmtReturn(_Result.Processed)]

    case = StmtBlock()
    case.addstmts(body)
    return CaseLabel(md.pqMsgId()), case
def genRecvCase(self, md):
    """Return ``(label, case)`` handling an ordinary incoming message.

    Async messages with returns answer through a resolver (makeResolver);
    all others declare plain locals for their returns, which makeReply then
    serializes.
    """
    incoming = self.deserializeMessage(
        md, self.side, errfn=errfnRecv, errfnSent=errfnSentinel(_Result.ValuError)
    )
    idvar, saveIdStmts = self.saveActorId(md)

    if md.decl.type.isAsync() and md.returns:
        declstmts = self.makeResolver(md, errfnRecv, routingId=idvar)
    else:
        declstmts = [
            StmtDecl(Decl(ret.bareType(self.side), ret.var().name), initargs=[])
            for ret in md.returns
        ]

    body = incoming + saveIdStmts + declstmts
    body += self.invokeRecvHandler(md)
    body.append(Whitespace.NL)
    body += self.makeReply(md, errfnRecv, routingId=idvar)
    body.append(StmtReturn(_Result.Processed))

    case = StmtBlock()
    case.addstmts(body)
    return CaseLabel(md.pqMsgId()), case
# helper methods
def makeMessage(self, md, errfn, fromActor=None):
    """Return ``(msgvar, stmts)`` that construct and serialize |md|'s message.

    The statements declare ``UniquePtr<IPC::Message>`` routed for
    |fromActor| (or this actor), open an ``IPC::MessageWriter`` on it, write
    every parameter, then apply setMessageFlags. |errfn| is not used here.
    """
    msgvar = self.msgvar
    writervar = ExprVar("writer__")
    routingId = self.protocol.routingId(fromActor)
    sender = fromActor or ExprVar.THIS

    stmts = [
        StmtDecl(
            Decl(Type("UniquePtr<IPC::Message>"), msgvar.name),
            init=ExprCall(ExprVar(md.pqMsgCtorFunc()), args=[routingId]),
        ),
        StmtDecl(
            Decl(Type("IPC::MessageWriter"), writervar.name),
            initargs=[ExprDeref(msgvar), sender],
        ),
        Whitespace.NL,
    ]
    for param in md.params:
        stmts.append(
            _ParamTraits.checkedWrite(
                param.ipdltype,
                param.var(),
                ExprAddrOf(writervar),
                sentinelKey=param.name,
            )
        )
    stmts.append(Whitespace.NL)
    stmts.extend(self.setMessageFlags(md, msgvar))
    return msgvar, stmts
def makeResolver(self, md, errfn, routingId):
    # Return the statements that declare the C++ `resolver` callback handed
    # to the Recv handler of an async message that has a reply.
    #
    # The generated code pre-builds the reply message (copying the incoming
    # message's seqno), wraps it in an IPDLResolverInner, and defines
    # `resolver` so that calling it serializes the resolved value(s) via
    # IPDLResolverInner::Resolve and logs the reply.
    #
    # Returns [] for non-async messages or messages without a reply.
    # |routingId| defaults to this protocol's routing ID; |errfn| is unused.
    if routingId is None:
        routingId = self.protocol.routingId()
    if not md.decl.type.isAsync() or not md.hasReply():
        return []

    def paramValue(idx):
        # With several returns the resolver receives a tuple, so element
        # |idx| is selected with std::get; otherwise aParam is the value.
        assert idx < len(md.returns)
        if len(md.returns) > 1:
            return ExprCode("std::get<${idx}>(aParam)", idx=idx)
        return ExprVar("aParam")

    serializeParams = [
        _ParamTraits.checkedWrite(
            p.ipdltype,
            paramValue(idx),
            ExprAddrOf(ExprVar("writer__")),
            sentinelKey=p.name,
        )
        for idx, p in enumerate(md.returns)
    ]
    return [
        StmtCode(
            """
UniquePtr<IPC::Message> ${replyvar}(${replyCtor}(${routingId}));
${replyvar}->set_seqno(${msgvar}.seqno());
RefPtr<mozilla::ipc::IPDLResolverInner> resolver__ =
new mozilla::ipc::IPDLResolverInner(std::move(${replyvar}), this);
${resolvertype} resolver = [resolver__ = std::move(resolver__)](${resolveType} aParam)
{
resolver__->Resolve([&] (IPC::Message* ${replyvar}, IProtocol* self__) {
IPC::MessageWriter writer__(*${replyvar}, self__);
$*{serializeParams}
${logSendingReply}
});
};
""",
            msgvar=self.msgvar,
            resolvertype=Type(md.resolverName()),
            routingId=routingId,
            resolveType=_resolveType(md.returns, self.side),
            replyvar=self.replyvar,
            replyCtor=ExprVar(md.pqReplyCtorFunc()),
            serializeParams=serializeParams,
            # Logged from inside the Resolve lambda, where the actor is
            # `self__` rather than `this`.
            logSendingReply=self.logMessage(
                md,
                self.replyvar,
                "Sending reply ",
                actor=ExprVar("self__"),
            ),
        )
    ]
def makeReply(self, md, errfn, routingId):
    """Return statements building the synchronous reply for |md|.

    The reply is assigned to ``self.replyvar``, every return value is
    serialized into it, message flags are applied and the send is logged.
    Returns [] when |md| has no reply or is async (async replies go through
    makeResolver). |routingId| defaults to this protocol's routing ID;
    |errfn| is unused.
    """
    if routingId is None:
        routingId = self.protocol.routingId()
    # TODO special cases for async ctor/dtor replies
    msgtype = md.decl.type
    if msgtype.isAsync() or not msgtype.hasReply():
        return []

    replyvar = self.replyvar
    stmts = [
        StmtExpr(
            ExprAssn(
                replyvar,
                ExprCall(ExprVar(md.pqReplyCtorFunc()), args=[routingId]),
            )
        ),
        StmtDecl(
            Decl(Type("IPC::MessageWriter"), "writer__"),
            initargs=[ExprDeref(replyvar), ExprVar.THIS],
        ),
        Whitespace.NL,
    ]
    for ret in md.returns:
        stmts.append(
            _ParamTraits.checkedWrite(
                ret.ipdltype,
                ret.var(),
                ExprAddrOf(ExprVar("writer__")),
                sentinelKey=ret.name,
            )
        )
    stmts += self.setMessageFlags(md, replyvar)
    stmts.append(self.logMessage(md, replyvar, "Sending reply "))
    return stmts
def setMessageFlags(self, md, var, seqno=None):
    """Return flag-setting statements for message |var|, ending in a newline.

    Only ``set_seqno(seqno)`` is emitted, and only when |seqno| is truthy.
    |md| is unused.
    """
    if seqno:
        flags = [
            StmtExpr(ExprCall(ExprSelect(var, "->", "set_seqno"), args=[seqno]))
        ]
    else:
        flags = []
    flags.append(Whitespace.NL)
    return flags
def deserializeMessage(self, md, side, errfn, errfnSent):
    """Return statements that log and unpack the received message for |md|.

    Always emits the receive log and profiler label. When |md| has params,
    also opens an ``IPC::MessageReader`` on ``self.msgvar``, reads each
    param (wrapped in ``Tainted<>`` for tainted messages unless the param is
    marked ``NoTaint``), and ends the read. For constructors the first param
    is read as a raw ``ActorHandle`` into ``self.handlevar``, because the
    actor does not exist on this side yet.
    """
    msgvar = self.msgvar
    readervar = ExprVar("reader__")
    stmts = [
        self.logMessage(md, ExprAddrOf(msgvar), "Received ", receiving=True),
        self.profilerLabel(md),
        Whitespace.NL,
    ]
    if not md.params:
        return stmts

    reads = []
    first = 0
    if md.decl.type.isCtor():
        handletype = Type("ActorHandle")
        reads.append(
            _ParamTraits.checkedRead(
                None,
                handletype,
                self.handlevar,
                ExprAddrOf(readervar),
                errfn,
                "'%s'" % handletype.name,
                sentinelKey="actor",
                errfnSentinel=errfnSent,
            )
        )
        first = 1

    for param in md.params[first:]:
        if md.decl.type.tainted and "NoTaint" not in param.attributes:
            readtype = Type("Tainted", T=param.bareType(side))
        else:
            readtype = param.bareType(side)
        reads.append(
            _ParamTraits.checkedRead(
                param.ipdltype,
                readtype,
                param.var(),
                ExprAddrOf(readervar),
                errfn,
                "'%s'" % param.ipdltype.name(),
                sentinelKey=param.name,
                errfnSentinel=errfnSent,
            )
        )

    stmts.append(
        StmtDecl(
            Decl(Type("IPC::MessageReader"), readervar.name),
            initargs=[msgvar, ExprVar.THIS],
        )
    )
    stmts.append(Whitespace.NL)
    stmts.extend(reads)
    stmts.append(StmtCode("${reader}.EndRead();\n", reader=readervar))
    return stmts
def asyncReplyCallbackImpl(self, md, side, resolvecb):
    # Return a C++ lambda expression that reads an async reply for |md| from
    # an `IPC::MessageReader* aReader`, ends the read, and calls the captured
    # |resolvecb| with the value (or a std::make_tuple of the values when
    # there are several returns). Read failures call aReader->FatalError and
    # return _Result.ValuError.
    assert md.returns
    readervar = ExprVar("aReader")

    # Custom error function handler which calls FatalError on the reader
    def errfn(msg, errcode=_Result.ValuError):
        return [
            StmtExpr(
                ExprCall(
                    ExprSelect(readervar, "->", "FatalError"),
                    args=[ExprLiteral.String(msg)],
                )
            ),
            StmtReturn(errcode),
        ]

    start, reads = 0, []
    if md.decl.type.isCtor():
        # return the raw actor handle so that its ID can be used
        # to construct the "real" actor
        # NOTE(review): this skips md.returns[0] below, yet resolvearg still
        # moves md.returns[0].var(), which was never read here. Looks
        # copied from deserializeMessage (where params[0] is the actor);
        # confirm whether ctors with returns can reach this path.
        handlevar = self.handlevar
        handletype = Type("ActorHandle")
        reads = [
            _ParamTraits.checkedRead(
                None,
                handletype,
                handlevar,
                readervar,
                errfn,
                "'%s'" % handletype.name,
                sentinelKey="actor",
                errfnSentinel=errfnSentinel(_Result.ValuError),
            )
        ]
        start = 1

    # aReader is already a pointer, so it is passed directly (no ExprAddrOf).
    reads += [
        _ParamTraits.checkedRead(
            p.ipdltype,
            p.bareType(side),
            p.var(),
            readervar,
            errfn,
            "'%s'" % p.ipdltype.name(),
            sentinelKey=p.name,
            errfnSentinel=errfnSentinel(_Result.ValuError),
        )
        for p in md.returns[start:]
    ]

    if len(md.returns) > 1:
        resolvearg = ExprCall(
            ExprVar("std::make_tuple"),
            args=[ExprMove(p.var()) for p in md.returns],
        )
    else:
        resolvearg = ExprMove(md.returns[0].var())

    return ExprCode(
        """
[resolve=std::move(${resolvecb})](IPC::MessageReader* aReader) mutable {
$*{reads}
aReader->EndRead();
resolve(${resolvearg});
return MsgProcessed;
}
""",
        resolvecb=resolvecb,
        reads=reads,
        resolvearg=resolvearg,
    )
def deserializeReply(self, md, replyexpr, side, errfn, errfnSentinel, actor=None):
    """Return statements that log and unpack the synchronous reply for |md|.

    Each return value is read from |replyexpr| into a ``<name>__reply``
    temporary and then move-assigned through the corresponding outparam
    pointer.

    Args:
        replyexpr: the ``UniquePtr<Message>`` holding the reply. It is used
            both for the log line and as the reader's source.
        side: actor side used for the C++ types of the temporaries.
        errfn, errfnSentinel: error handlers forwarded to checkedRead.
        actor: optional actor expression for logging (defaults to ``this``).
    """
    stmts = [
        Whitespace.NL,
        self.logMessage(md, replyexpr, "Received reply ", actor, receiving=True),
    ]
    if not md.returns:
        return stmts

    def tempvar(r):
        return ExprVar(r.var().name + "__reply")

    readervar = ExprVar("reader__")
    stmts.extend(
        [
            Whitespace.NL,
            # Read from |replyexpr| (not a hard-coded self.replyvar) so the
            # reader and the log line always refer to the same message.
            StmtDecl(
                Decl(Type("IPC::MessageReader"), readervar.name),
                initargs=[ExprDeref(replyexpr), ExprVar.THIS],
            ),
        ]
        + [Whitespace.NL]
        + [
            _ParamTraits.checkedRead(
                r.ipdltype,
                r.bareType(side),
                tempvar(r),
                ExprAddrOf(readervar),
                errfn,
                "'%s'" % r.ipdltype.name(),
                sentinelKey=r.name,
                errfnSentinel=errfnSentinel,
            )
            for r in md.returns
        ]
        # Move-assign the values out of the variables created with
        # checkedRead into outparams.
        + [
            StmtExpr(ExprAssn(ExprDeref(r.var()), ExprMove(tempvar(r))))
            for r in md.returns
        ]
        + [StmtCode("${reader}.EndRead();", reader=readervar)]
    )
    return stmts
def sendAsync(self, md, msgexpr, actor=None):
    # Return (retvar, stmts) that log and asynchronously send |msgexpr|
    # through ChannelSend (on |actor| when given, else on this).
    #
    # With returns: declares `int32_t seqno__ = 0`; if ChannelSend succeeds,
    # registers the reply callback (built by asyncReplyCallbackImpl around
    # `aResolve`) with mAsyncCallbacks, otherwise calls
    # aReject(SendError). retvar is None in this case.
    #
    # Without returns: declares `bool sendok__ = ChannelSend(...)` and
    # returns sendok__ as retvar.
    sendok = ExprVar("sendok__")
    seqno = ExprVar("seqno__")

    stmts = [
        Whitespace.NL,
        self.logMessage(md, msgexpr, "Sending ", actor),
        self.profilerLabel(md),
    ]
    stmts.append(Whitespace.NL)

    # Generate the actual call expression.
    send = ExprVar("ChannelSend")
    if actor is not None:
        send = ExprSelect(actor, "->", send.name)

    if md.returns:
        stmts.append(StmtDecl(Decl(Type.INT32, seqno.name), init=ExprLiteral.ZERO))
        ifsendok = StmtIf(
            ExprCall(send, args=[ExprMove(msgexpr), ExprAddrOf(seqno)])
        )
        stmts.append(ifsendok)

        callback = self.asyncReplyCallbackImpl(md, self.side, ExprVar("aResolve"))
        ifsendok.addifstmt(
            StmtCode(
                """
mAsyncCallbacks.AddCallback(${seqno}, ${replyid}, ${callback}, std::move(aReject));
""",
                seqno=seqno,
                replyid=md.pqReplyId(),
                callback=callback,
            )
        )
        ifsendok.addelsestmt(
            StmtCode(
                """
aReject(::mozilla::ipc::ResponseRejectReason::SendError);
"""
            )
        )
        # The generated method returns void here (see makeSendMethodDecl's
        # "callback" return semantics).
        retvar = None
    else:
        stmts.append(
            StmtDecl(
                Decl(Type.BOOL, sendok.name),
                init=ExprCall(send, args=[ExprMove(msgexpr)]),
            )
        )
        retvar = sendok

    return (retvar, stmts)
def sendBlocking(self, md, msgexpr, replyexpr, actor=None):
    # Return (sendok, stmts) that log and synchronously send |msgexpr|
    # through ChannelSend (on |actor| when given, else on this), storing the
    # reply into |replyexpr|.
    #
    # The generated code declares `bool sendok__ = false`, then, inside its
    # own block, opens an AUTO_PROFILER_TRACING_MARKER("Sync IPC",
    # "<Protocol>::<Msg>", IPC) and assigns the ChannelSend result to
    # sendok__. Side effect: records the mozilla/ProfilerMarkers.h include.
    send = ExprVar("ChannelSend")
    if actor is not None:
        send = ExprSelect(actor, "->", send.name)

    sendok = ExprVar("sendok__")
    self.externalIncludes.add("mozilla/ProfilerMarkers.h")
    return (
        sendok,
        (
            [
                Whitespace.NL,
                self.logMessage(md, msgexpr, "Sending ", actor),
                self.profilerLabel(md),
            ]
            + [
                Whitespace.NL,
                StmtDecl(Decl(Type.BOOL, sendok.name), init=ExprLiteral.FALSE),
                # Scoped block: presumably bounds the tracing marker to the
                # duration of the send call.
                StmtBlock(
                    [
                        StmtExpr(
                            ExprCall(
                                ExprVar("AUTO_PROFILER_TRACING_MARKER"),
                                [
                                    ExprLiteral.String("Sync IPC"),
                                    ExprLiteral.String(
                                        self.protocol.name
                                        + "::"
                                        + md.prettyMsgName()
                                    ),
                                    ExprVar("IPC"),
                                ],
                            )
                        ),
                        StmtExpr(
                            ExprAssn(
                                sendok,
                                ExprCall(
                                    send,
                                    args=[ExprMove(msgexpr), ExprAddrOf(replyexpr)],
                                ),
                            )
                        ),
                    ]
                ),
            ]
        ),
    )
def sendAsyncWithPromise(self, md):
    """Return the body of the promise-returning Send overload for |md|.

    The generated code creates a promise (with direct task dispatch), calls
    the callback-based Send overload with resolve/reject lambdas that settle
    it, and returns the promise.
    """
    promise = _makePromise(md.returns, self.side, resolver=True)
    returns = md.returns
    if len(returns) > 1:
        resolvetype = _tuple([ret.bareType(self.side) for ret in returns])
    else:
        resolvetype = returns[0].bareType(self.side)

    onresolve = ExprCode(
        """
[promise__](${resolvetype}&& aValue) {
promise__->Resolve(std::move(aValue), __func__);
}
""",
        resolvetype=resolvetype,
    )
    onreject = ExprCode(
        """
[promise__](ResponseRejectReason&& aReason) {
promise__->Reject(std::move(aReason), __func__);
}
""",
        resolvetype=resolvetype,
    )

    sendargs = [ExprMove(param.var()) for param in md.params]
    sendargs += [onresolve, onreject]

    forward = StmtCode(
        """
RefPtr<${promise}> promise__ = new ${promise}(__func__);
promise__->UseDirectTaskDispatch(__func__);
${send}($,{args});
return promise__;
""",
        promise=promise,
        send=md.sendMethod(),
        args=sendargs,
    )
    return [forward]
def callAllocActor(self, md, retsems, side):
    """Return a declaration initializing the new actor from its Alloc method.

    Refcounted actor types are held in a RefPtr rather than a raw pointer.
    """
    actordecl = md.actorDecl()
    constructed = md.decl.type.constructedType()

    holdertype = actordecl.bareType(self.side)
    if constructed.isRefcounted():
        holdertype.ptr = False
        holdertype = _refptr(holdertype)

    allocexpr = self.thisCall(
        _allocMethod(constructed, side),
        args=md.makeCxxArgs(retsems=retsems, retcallsems="out", implicit=False),
    )
    return StmtDecl(Decl(holdertype, actordecl.var().name), init=allocexpr)
def invokeRecvHandler(self, md):
    """Return statements that call the Recv handler and check its result.

    The IPCResult is stored in ``__ok``. On failure the generated code hits
    the protocol-error breakpoint and returns ``_Result.ProcessingError``.
    Async messages with returns pass their resolver instead of out-values.
    """
    usesresolver = md.decl.type.isAsync() and md.returns
    retsems = "resolver" if usesresolver else "in"

    handlercall = self.thisCall(
        md.recvMethod(),
        md.makeCxxArgs(paramsems="move", retsems=retsems, retcallsems="out"),
    )
    okdecl = StmtDecl(
        Decl(Type("mozilla::ipc::IPCResult"), "__ok"), init=handlercall
    )

    onfailure = StmtIf(ExprNot(ExprVar("__ok")))
    onfailure.addifstmts(
        [
            _protocolErrorBreakpoint("Handler returned error code!"),
            Whitespace("// Error handled in mozilla::ipc::IPCResult\n", indent=True),
            StmtReturn(_Result.ProcessingError),
        ]
    )
    return [okdecl, onfailure]
def makeDtorMethodDecl(self, md, actorvar):
    """Return the static Send declaration for a __delete__ message.

    It is the regular Send declaration with the actor being destroyed
    prepended as parameter |actorvar|.
    """
    decl = self.makeSendMethodDecl(md)
    actortype = _cxxInType(
        ipdl.type.ActorType(md.decl.type.constructedType()),
        side=self.side,
        direction="send",
    )
    decl.params.insert(0, Decl(actortype, actorvar.name))
    decl.methodspec = MethodSpec.STATIC
    return decl
def makeSendMethodDecl(self, md, promise=False, paramsems="in"):
    """Return the MethodDecl of the Send method for |md|.

    Return type by case:
    - async with returns: RefPtr<promise> when |promise| is set, otherwise
      void (callback-style);
    - everything else: bool (|promise| must be False);
    - constructors: overridden to the actor's bare type.
    [[nodiscard]] is requested on the parent side for non-callback sends
    and on sync constructors.
    """
    msgtype = md.decl.type
    if not (msgtype.isAsync() and md.returns):
        assert not promise
        returnsems, rettype = "out", Type.BOOL
    elif promise:
        returnsems, rettype = "promise", _refptr(Type(md.promiseName()))
    else:
        returnsems, rettype = "callback", Type.VOID

    wantwarn = (self.side == "parent" and returnsems != "callback") or (
        msgtype.isCtor() and not msgtype.isAsync()
    )
    params = md.makeCxxParams(
        paramsems,
        returnsems=returnsems,
        side=self.side,
        implicit=msgtype.hasImplicitActorParam(),
        direction="send",
    )
    decl = MethodDecl(md.sendMethod(), params=params, warn_unused=wantwarn, ret=rettype)
    if msgtype.isCtor():
        decl.ret = md.actorDecl().bareType(self.side)
    return decl
def logMessage(self, md, msgptr, pfx, actor=None, receiving=False):
    """Return a statement logging |msgptr| when protocol logging is enabled.

    |pfx| is the log prefix. |actor| (default ``this``) supplies the other
    side's PID via ToplevelProtocol(). |receiving| picks the logged
    direction. |md| is unused.
    """
    template = """
if (mozilla::ipc::LoggingEnabledFor(${protocolname}, ${side})) {
mozilla::ipc::LogMessageForProtocol(
${actorname},
${actor}->ToplevelProtocol()->OtherPidMaybeInvalid(),
${pfx},
${msgptr}->type(),
mozilla::ipc::MessageDirection::${direction});
}
"""
    protoname = self.protocol.name
    displayname = _actorName(protoname, self.side)
    return StmtCode(
        template,
        protocolname=ExprLiteral.String(protoname),
        side=_cxxSide(self.side),
        actorname=ExprLiteral.String(displayname),
        actor=actor or ExprVar.THIS,
        pfx=ExprLiteral.String(pfx),
        msgptr=msgptr,
        direction="eReceiving" if receiving else "eSending",
    )
def profilerLabel(self, md):
    """Return an AUTO_PROFILER_LABEL statement named "<Protocol>::<Msg>".

    Also records that the generated file must include
    mozilla/ProfilerLabels.h.
    """
    self.externalIncludes.add("mozilla/ProfilerLabels.h")
    labelcode = """
AUTO_PROFILER_LABEL("${name}::${msgname}", OTHER);
"""
    return StmtCode(labelcode, name=self.protocol.name, msgname=md.prettyMsgName())
def saveActorId(self, md):
    """Return ``(idvar, stmts)`` where stmts declare ``id__`` as this routing ID.

    The declaration is emitted only when |md| has a reply, because
    otherwise the C++ variable would be unused and trigger warnings.
    """
    idvar = ExprVar("id__")
    if not md.decl.type.hasReply():
        return idvar, []
    savestmt = StmtDecl(Decl(_actorIdType(), idvar.name), self.protocol.routingId())
    return idvar, [savestmt]
class _GenerateProtocolParentCode(_GenerateProtocolActorCode):
    """Actor-code generator for the parent side of a protocol."""

    def __init__(self):
        super().__init__("parent")

    def sendsMessage(self, md):
        """The parent sends every message whose direction is not `in`."""
        msgtype = md.decl.type
        return not msgtype.isIn()

    def receivesMessage(self, md):
        """The parent receives `inout` and `in` messages."""
        msgtype = md.decl.type
        return msgtype.isInout() or msgtype.isIn()
class _GenerateProtocolChildCode(_GenerateProtocolActorCode):
    """Actor-code generator for the child side of a protocol."""

    def __init__(self):
        super().__init__("child")

    def sendsMessage(self, md):
        """The child sends every message whose direction is not `out`."""
        msgtype = md.decl.type
        return not msgtype.isOut()

    def receivesMessage(self, md):
        """The child receives `inout` and `out` messages."""
        msgtype = md.decl.type
        return msgtype.isInout() or msgtype.isOut()
# -----------------------------------------------------------------------------
# Utility passes
##
def _splitClassDeclDefn(cls):
    """Destructively split |cls| methods into declarations and
    definitions (if |not methodDecl.force_inline|). Return classDecl,
    methodDefns."""
    defns = Block()
    for index, member in enumerate(cls.stmts):
        # Inline methods and non-method members stay in the class body.
        if not isinstance(member, MethodDefn) or member.decl.force_inline:
            continue
        decl, defn = _splitMethodDeclDefn(member, cls)
        cls.stmts[index] = StmtDecl(decl)
        # Pure methods produce no out-of-line definition.
        if defn:
            defns.addstmts([defn, Whitespace.NL])
    return cls, defns
def _splitMethodDeclDefn(md, cls):
    """Split method |md| of |cls| into ``(declaration, definition)``.

    The declaration is a deep copy of the original MethodDecl. |md| itself
    is mutated into the out-of-line definition: it is qualified with |cls|
    and loses method specifiers, [[nodiscard]] and default arguments.
    Pure methods return ``(md.decl, None)``.
    """
    if md.decl.methodspec == MethodSpec.PURE:
        return md.decl, None

    declcopy = deepcopy(md.decl)
    defndecl = md.decl
    defndecl.cls = cls
    # Don't emit method specifiers on method defns.
    defndecl.methodspec = MethodSpec.NONE
    defndecl.warn_unused = False
    defndecl.only_for_definition = True
    for param in defndecl.params:
        if isinstance(param, Param):
            param.default = None
    return declcopy, md
def _splitFuncDeclDefn(fun):
    """Split non-inline free function |fun| into ``(decl statement, definition)``."""
    assert not fun.decl.force_inline
    declstmt = StmtDecl(fun.decl)
    return declstmt, fun