Mini Shell

Direktori : /proc/self/root/opt/imh-python/lib/python3.9/site-packages/twisted/test/
Upload File :
Current File : //proc/self/root/opt/imh-python/lib/python3.9/site-packages/twisted/test/test_amp.py

# Copyright (c) 2005 Divmod, Inc.
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.

"""
Tests for L{twisted.protocols.amp}.
"""


import datetime
import decimal
from typing import Dict, Type
from unittest import skipIf

from zope.interface import implementer
from zope.interface.verify import verifyClass, verifyObject

from twisted.internet import address, defer, error, interfaces, protocol, reactor
from twisted.protocols import amp
from twisted.python import filepath
from twisted.python.failure import Failure
from twisted.test import iosim
from twisted.test.proto_helpers import StringTransport
from twisted.trial.unittest import TestCase

try:
    from twisted.internet import ssl as _ssl
except ImportError:
    ssl = None
else:
    if not _ssl.supported:
        ssl = None
    else:
        ssl = _ssl

if ssl is None:
    skipSSL = True
else:
    skipSSL = False

if not interfaces.IReactorSSL.providedBy(reactor):
    reactorLacksSSL = True
else:
    reactorLacksSSL = False


tz = amp._FixedOffsetTZInfo.fromSignHoursMinutes


class TestProto(protocol.Protocol):
    """
    A trivial protocol for use in testing where a L{Protocol} is expected.

    @ivar instanceId: the id of this instance
    @ivar onConnLost: deferred that will fired when the connection is lost
    @ivar dataToSend: data to send on the protocol
    """

    instanceCount = 0

    def __init__(self, onConnLost, dataToSend):
        assert isinstance(dataToSend, bytes), repr(dataToSend)
        self.onConnLost = onConnLost
        self.dataToSend = dataToSend
        self.instanceId = TestProto.instanceCount
        TestProto.instanceCount = TestProto.instanceCount + 1

    def connectionMade(self):
        self.data = []
        self.transport.write(self.dataToSend)

    def dataReceived(self, bytes):
        self.data.append(bytes)

    def connectionLost(self, reason):
        self.onConnLost.callback(self.data)

    def __repr__(self) -> str:
        """
        Custom repr for testing to avoid coupling amp tests with repr from
        L{Protocol}

        Returns a string which contains a unique identifier that can be looked
        up using the instanceId property::

            <TestProto #3>
        """
        return "<TestProto #%d>" % (self.instanceId,)


class SimpleSymmetricProtocol(amp.AMP):
    def sendHello(self, text):
        return self.callRemoteString(b"hello", hello=text)

    def amp_HELLO(self, box):
        return amp.Box(hello=box[b"hello"])


class UnfriendlyGreeting(Exception):
    """Greeting was insufficiently kind."""


class DeathThreat(Exception):
    """Greeting was insufficiently kind."""


class UnknownProtocol(Exception):
    """Asked to switch to the wrong protocol."""


class TransportPeer(amp.Argument):
    # this serves as some informal documentation for how to get variables from
    # the protocol or your environment and pass them to methods as arguments.
    def retrieve(self, d, name, proto):
        return b""

    def fromStringProto(self, notAString, proto):
        return proto.transport.getPeer()

    def toBox(self, name, strings, objects, proto):
        return


class Hello(amp.Command):

    commandName = b"hello"

    arguments = [
        (b"hello", amp.String()),
        (b"optional", amp.Boolean(optional=True)),
        (b"print", amp.Unicode(optional=True)),
        (b"from", TransportPeer(optional=True)),
        (b"mixedCase", amp.String(optional=True)),
        (b"dash-arg", amp.String(optional=True)),
        (b"underscore_arg", amp.String(optional=True)),
    ]

    response = [(b"hello", amp.String()), (b"print", amp.Unicode(optional=True))]

    errors: Dict[Type[Exception], bytes] = {UnfriendlyGreeting: b"UNFRIENDLY"}

    fatalErrors: Dict[Type[Exception], bytes] = {DeathThreat: b"DEAD"}


class NoAnswerHello(Hello):
    commandName = Hello.commandName
    requiresAnswer = False


class FutureHello(amp.Command):
    commandName = b"hello"

    arguments = [
        (b"hello", amp.String()),
        (b"optional", amp.Boolean(optional=True)),
        (b"print", amp.Unicode(optional=True)),
        (b"from", TransportPeer(optional=True)),
        (b"bonus", amp.String(optional=True)),  # addt'l arguments
        # should generally be
        # added at the end, and
        # be optional...
    ]

    response = [(b"hello", amp.String()), (b"print", amp.Unicode(optional=True))]

    errors = {UnfriendlyGreeting: b"UNFRIENDLY"}


class WTF(amp.Command):
    """
    An example of an invalid command.
    """


class BrokenReturn(amp.Command):
    """An example of a perfectly good command, but the handler is going to return
    None...
    """

    commandName = b"broken_return"


class Goodbye(amp.Command):
    # commandName left blank on purpose: this tests implicit command names.
    response = [(b"goodbye", amp.String())]
    responseType = amp.QuitBox


class WaitForever(amp.Command):
    commandName = b"wait_forever"


class GetList(amp.Command):
    commandName = b"getlist"
    arguments = [(b"length", amp.Integer())]
    response = [(b"body", amp.AmpList([(b"x", amp.Integer())]))]


class DontRejectMe(amp.Command):
    commandName = b"dontrejectme"
    arguments = [
        (b"magicWord", amp.Unicode()),
        (b"list", amp.AmpList([(b"name", amp.Unicode())], optional=True)),
    ]
    response = [(b"response", amp.Unicode())]


class SecuredPing(amp.Command):
    # XXX TODO: actually make this refuse to send over an insecure connection
    response = [(b"pinged", amp.Boolean())]


class TestSwitchProto(amp.ProtocolSwitchCommand):
    commandName = b"Switch-Proto"

    arguments = [
        (b"name", amp.String()),
    ]
    errors = {UnknownProtocol: b"UNKNOWN"}


class SingleUseFactory(protocol.ClientFactory):
    def __init__(self, proto):
        self.proto = proto
        self.proto.factory = self

    def buildProtocol(self, addr):
        p, self.proto = self.proto, None
        return p

    reasonFailed = None

    def clientConnectionFailed(self, connector, reason):
        self.reasonFailed = reason
        return


THING_I_DONT_UNDERSTAND = b"gwebol nargo"


class ThingIDontUnderstandError(Exception):
    pass


class FactoryNotifier(amp.AMP):
    factory = None

    def connectionMade(self):
        if self.factory is not None:
            self.factory.theProto = self
            if hasattr(self.factory, "onMade"):
                self.factory.onMade.callback(None)

    def emitpong(self):
        from twisted.internet.interfaces import ISSLTransport

        if not ISSLTransport.providedBy(self.transport):
            raise DeathThreat("only send secure pings over secure channels")
        return {"pinged": True}

    SecuredPing.responder(emitpong)


class SimpleSymmetricCommandProtocol(FactoryNotifier):
    maybeLater = None

    def __init__(self, onConnLost=None):
        amp.AMP.__init__(self)
        self.onConnLost = onConnLost

    def sendHello(self, text):
        return self.callRemote(Hello, hello=text)

    def sendUnicodeHello(self, text, translation):
        return self.callRemote(Hello, hello=text, Print=translation)

    greeted = False

    def cmdHello(
        self,
        hello,
        From,
        optional=None,
        Print=None,
        mixedCase=None,
        dash_arg=None,
        underscore_arg=None,
    ):
        assert From == self.transport.getPeer()
        if hello == THING_I_DONT_UNDERSTAND:
            raise ThingIDontUnderstandError()
        if hello.startswith(b"fuck"):
            raise UnfriendlyGreeting("Don't be a dick.")
        if hello == b"die":
            raise DeathThreat("aieeeeeeeee")
        result = dict(hello=hello)
        if Print is not None:
            result.update(dict(Print=Print))
        self.greeted = True
        return result

    Hello.responder(cmdHello)

    def cmdGetlist(self, length):
        return {"body": [dict(x=1)] * length}

    GetList.responder(cmdGetlist)

    def okiwont(self, magicWord, list=None):
        if list is None:
            response = "list omitted"
        else:
            response = "%s accepted" % (list[0]["name"])
        return dict(response=response)

    DontRejectMe.responder(okiwont)

    def waitforit(self):
        self.waiting = defer.Deferred()
        return self.waiting

    WaitForever.responder(waitforit)

    def saybye(self):
        return dict(goodbye=b"everyone")

    Goodbye.responder(saybye)

    def switchToTestProtocol(self, fail=False):
        if fail:
            name = b"no-proto"
        else:
            name = b"test-proto"
        p = TestProto(self.onConnLost, SWITCH_CLIENT_DATA)
        return self.callRemote(
            TestSwitchProto, SingleUseFactory(p), name=name
        ).addCallback(lambda ign: p)

    def switchit(self, name):
        if name == b"test-proto":
            return TestProto(self.onConnLost, SWITCH_SERVER_DATA)
        raise UnknownProtocol(name)

    TestSwitchProto.responder(switchit)

    def donothing(self):
        return None

    BrokenReturn.responder(donothing)


class DeferredSymmetricCommandProtocol(SimpleSymmetricCommandProtocol):
    def switchit(self, name):
        if name == b"test-proto":
            self.maybeLaterProto = TestProto(self.onConnLost, SWITCH_SERVER_DATA)
            self.maybeLater = defer.Deferred()
            return self.maybeLater

    TestSwitchProto.responder(switchit)


class BadNoAnswerCommandProtocol(SimpleSymmetricCommandProtocol):
    def badResponder(
        self,
        hello,
        From,
        optional=None,
        Print=None,
        mixedCase=None,
        dash_arg=None,
        underscore_arg=None,
    ):
        """
        This responder does nothing and forgets to return a dictionary.
        """

    NoAnswerHello.responder(badResponder)


class NoAnswerCommandProtocol(SimpleSymmetricCommandProtocol):
    def goodNoAnswerResponder(
        self,
        hello,
        From,
        optional=None,
        Print=None,
        mixedCase=None,
        dash_arg=None,
        underscore_arg=None,
    ):
        return dict(hello=hello + b"-noanswer")

    NoAnswerHello.responder(goodNoAnswerResponder)


def connectedServerAndClient(
    ServerClass=SimpleSymmetricProtocol, ClientClass=SimpleSymmetricProtocol, *a, **kw
):
    """Returns a 3-tuple: (client, server, pump)"""
    return iosim.connectedServerAndClient(ServerClass, ClientClass, *a, **kw)


class TotallyDumbProtocol(protocol.Protocol):
    buf = b""

    def dataReceived(self, data):
        self.buf += data


class LiteralAmp(amp.AMP):
    def __init__(self):
        self.boxes = []

    def ampBoxReceived(self, box):
        self.boxes.append(box)
        return


class AmpBoxTests(TestCase):
    """
    Test a few essential properties of AMP boxes, mostly with respect to
    serialization correctness.
    """

    def test_serializeStr(self):
        """
        Make sure that strs serialize to strs.
        """
        a = amp.AmpBox(key=b"value")
        self.assertEqual(type(a.serialize()), bytes)

    def test_serializeUnicodeKeyRaises(self):
        """
        Verify that TypeError is raised when trying to serialize Unicode keys.
        """
        a = amp.AmpBox(**{"key": "value"})
        self.assertRaises(TypeError, a.serialize)

    def test_serializeUnicodeValueRaises(self):
        """
        Verify that TypeError is raised when trying to serialize Unicode
        values.
        """
        a = amp.AmpBox(key="value")
        self.assertRaises(TypeError, a.serialize)


class ParsingTests(TestCase):
    def test_booleanValues(self):
        """
        Verify that the Boolean parser parses 'True' and 'False', but nothing
        else.
        """
        b = amp.Boolean()
        self.assertTrue(b.fromString(b"True"))
        self.assertFalse(b.fromString(b"False"))
        self.assertRaises(TypeError, b.fromString, b"ninja")
        self.assertRaises(TypeError, b.fromString, b"true")
        self.assertRaises(TypeError, b.fromString, b"TRUE")
        self.assertEqual(b.toString(True), b"True")
        self.assertEqual(b.toString(False), b"False")

    def test_pathValueRoundTrip(self):
        """
        Verify the 'Path' argument can parse and emit a file path.
        """
        fp = filepath.FilePath(self.mktemp())
        p = amp.Path()
        s = p.toString(fp)
        v = p.fromString(s)
        self.assertIsNot(fp, v)  # sanity check
        self.assertEqual(fp, v)

    def test_sillyEmptyThing(self):
        """
        Test that empty boxes raise an error; they aren't supposed to be sent
        on purpose.
        """
        a = amp.AMP()
        return self.assertRaises(amp.NoEmptyBoxes, a.ampBoxReceived, amp.Box())

    def test_ParsingRoundTrip(self):
        """
        Verify that various kinds of data make it through the encode/parse
        round-trip unharmed.
        """
        c, s, p = connectedServerAndClient(
            ClientClass=LiteralAmp, ServerClass=LiteralAmp
        )

        SIMPLE = (b"simple", b"test")
        CE = (b"ceq", b": ")
        CR = (b"crtest", b"test\r")
        LF = (b"lftest", b"hello\n")
        NEWLINE = (b"newline", b"test\r\none\r\ntwo")
        NEWLINE2 = (b"newline2", b"test\r\none\r\n two")
        BODYTEST = (b"body", b"blah\r\n\r\ntesttest")

        testData = [
            [SIMPLE],
            [SIMPLE, BODYTEST],
            [SIMPLE, CE],
            [SIMPLE, CR],
            [SIMPLE, CE, CR, LF],
            [CE, CR, LF],
            [SIMPLE, NEWLINE, CE, NEWLINE2],
            [BODYTEST, SIMPLE, NEWLINE],
        ]

        for test in testData:
            jb = amp.Box()
            jb.update(dict(test))
            jb._sendTo(c)
            p.flush()
            self.assertEqual(s.boxes[-1], jb)


class FakeLocator:
    """
    This is a fake implementation of the interface implied by
    L{CommandLocator}.
    """

    def __init__(self):
        """
        Remember the given keyword arguments as a set of responders.
        """
        self.commands = {}

    def locateResponder(self, commandName):
        """
        Look up and return a function passed as a keyword argument of the given
        name to the constructor.
        """
        return self.commands[commandName]


class FakeSender:
    """
    This is a fake implementation of the 'box sender' interface implied by
    L{AMP}.
    """

    def __init__(self):
        """
        Create a fake sender and initialize the list of received boxes and
        unhandled errors.
        """
        self.sentBoxes = []
        self.unhandledErrors = []
        self.expectedErrors = 0

    def expectError(self):
        """
        Expect one error, so that the test doesn't fail.
        """
        self.expectedErrors += 1

    def sendBox(self, box):
        """
        Accept a box, but don't do anything.
        """
        self.sentBoxes.append(box)

    def unhandledError(self, failure):
        """
        Deal with failures by instantly re-raising them for easier debugging.
        """
        self.expectedErrors -= 1
        if self.expectedErrors < 0:
            failure.raiseException()
        else:
            self.unhandledErrors.append(failure)


class CommandDispatchTests(TestCase):
    """
    The AMP CommandDispatcher class dispatches converts AMP boxes into commands
    and responses using Command.responder decorator.

    Note: Originally, AMP's factoring was such that many tests for this
    functionality are now implemented as full round-trip tests in L{AMPTests}.
    Future tests should be written at this level instead, to ensure API
    compatibility and to provide more granular, readable units of test
    coverage.
    """

    def setUp(self):
        """
        Create a dispatcher to use.
        """
        self.locator = FakeLocator()
        self.sender = FakeSender()
        self.dispatcher = amp.BoxDispatcher(self.locator)
        self.dispatcher.startReceivingBoxes(self.sender)

    def test_receivedAsk(self):
        """
        L{CommandDispatcher.ampBoxReceived} should locate the appropriate
        command in its responder lookup, based on the '_ask' key.
        """
        received = []

        def thunk(box):
            received.append(box)
            return amp.Box({"hello": "goodbye"})

        input = amp.Box(_command="hello", _ask="test-command-id", hello="world")
        self.locator.commands["hello"] = thunk
        self.dispatcher.ampBoxReceived(input)
        self.assertEqual(received, [input])

    def test_sendUnhandledError(self):
        """
        L{CommandDispatcher} should relay its unhandled errors in responding to
        boxes to its boxSender.
        """
        err = RuntimeError("something went wrong, oh no")
        self.sender.expectError()
        self.dispatcher.unhandledError(Failure(err))
        self.assertEqual(len(self.sender.unhandledErrors), 1)
        self.assertEqual(self.sender.unhandledErrors[0].value, err)

    def test_unhandledSerializationError(self):
        """
        Errors during serialization ought to be relayed to the sender's
        unhandledError method.
        """
        err = RuntimeError("something undefined went wrong")

        def thunk(result):
            class BrokenBox(amp.Box):
                def _sendTo(self, proto):
                    raise err

            return BrokenBox()

        self.locator.commands["hello"] = thunk
        input = amp.Box(_command="hello", _ask="test-command-id", hello="world")
        self.sender.expectError()
        self.dispatcher.ampBoxReceived(input)
        self.assertEqual(len(self.sender.unhandledErrors), 1)
        self.assertEqual(self.sender.unhandledErrors[0].value, err)

    def test_callRemote(self):
        """
        L{CommandDispatcher.callRemote} should emit a properly formatted '_ask'
        box to its boxSender and record an outstanding L{Deferred}.  When a
        corresponding '_answer' packet is received, the L{Deferred} should be
        fired, and the results translated via the given L{Command}'s response
        de-serialization.
        """
        D = self.dispatcher.callRemote(Hello, hello=b"world")
        self.assertEqual(
            self.sender.sentBoxes,
            [amp.AmpBox(_command=b"hello", _ask=b"1", hello=b"world")],
        )
        answers = []
        D.addCallback(answers.append)
        self.assertEqual(answers, [])
        self.dispatcher.ampBoxReceived(
            amp.AmpBox({b"hello": b"yay", b"print": b"ignored", b"_answer": b"1"})
        )
        self.assertEqual(answers, [dict(hello=b"yay", Print="ignored")])

    def _localCallbackErrorLoggingTest(self, callResult):
        """
        Verify that C{callResult} completes with a L{None} result and that an
        unhandled error has been logged.
        """
        finalResult = []
        callResult.addBoth(finalResult.append)

        self.assertEqual(1, len(self.sender.unhandledErrors))
        self.assertIsInstance(self.sender.unhandledErrors[0].value, ZeroDivisionError)

        self.assertEqual([None], finalResult)

    def test_callRemoteSuccessLocalCallbackErrorLogging(self):
        """
        If the last callback on the L{Deferred} returned by C{callRemote} (added
        by application code calling C{callRemote}) fails, the failure is passed
        to the sender's C{unhandledError} method.
        """
        self.sender.expectError()

        callResult = self.dispatcher.callRemote(Hello, hello=b"world")
        callResult.addCallback(lambda result: 1 // 0)

        self.dispatcher.ampBoxReceived(
            amp.AmpBox({b"hello": b"yay", b"print": b"ignored", b"_answer": b"1"})
        )

        self._localCallbackErrorLoggingTest(callResult)

    def test_callRemoteErrorLocalCallbackErrorLogging(self):
        """
        Like L{test_callRemoteSuccessLocalCallbackErrorLogging}, but for the
        case where the L{Deferred} returned by C{callRemote} fails.
        """
        self.sender.expectError()

        callResult = self.dispatcher.callRemote(Hello, hello=b"world")
        callResult.addErrback(lambda result: 1 // 0)

        self.dispatcher.ampBoxReceived(
            amp.AmpBox(
                {
                    b"_error": b"1",
                    b"_error_code": b"bugs",
                    b"_error_description": b"stuff",
                }
            )
        )

        self._localCallbackErrorLoggingTest(callResult)


class SimpleGreeting(amp.Command):
    """
    A very simple greeting command that uses a few basic argument types.
    """

    commandName = b"simple"
    arguments = [(b"greeting", amp.Unicode()), (b"cookie", amp.Integer())]
    response = [(b"cookieplus", amp.Integer())]


class TestLocator(amp.CommandLocator):
    """
    A locator which implements a responder to the 'simple' command.
    """

    def __init__(self):
        self.greetings = []

    def greetingResponder(self, greeting, cookie):
        self.greetings.append((greeting, cookie))
        return dict(cookieplus=cookie + 3)

    greetingResponder = SimpleGreeting.responder(greetingResponder)


class OverridingLocator(TestLocator):
    """
    A locator which overrides the responder to the 'simple' command.
    """

    def greetingResponder(self, greeting, cookie):
        """
        Return a different cookieplus than L{TestLocator.greetingResponder}.
        """
        self.greetings.append((greeting, cookie))
        return dict(cookieplus=cookie + 4)

    greetingResponder = SimpleGreeting.responder(greetingResponder)


class InheritingLocator(OverridingLocator):
    """
    This locator should inherit the responder from L{OverridingLocator}.
    """


class OverrideLocatorAMP(amp.AMP):
    def __init__(self):
        amp.AMP.__init__(self)
        self.customResponder = object()
        self.expectations = {b"custom": self.customResponder}
        self.greetings = []

    def lookupFunction(self, name):
        """
        Override the deprecated lookupFunction function.
        """
        if name in self.expectations:
            result = self.expectations[name]
            return result
        else:
            return super().lookupFunction(name)

    def greetingResponder(self, greeting, cookie):
        self.greetings.append((greeting, cookie))
        return dict(cookieplus=cookie + 3)

    greetingResponder = SimpleGreeting.responder(greetingResponder)


class CommandLocatorTests(TestCase):
    """
    The CommandLocator should enable users to specify responders to commands as
    functions that take structured objects, annotated with metadata.
    """

    def _checkSimpleGreeting(self, locatorClass, expected):
        """
        Check that a locator of type C{locatorClass} finds a responder
        for command named I{simple} and that the found responder answers
        with the C{expected} result to a C{SimpleGreeting<"ni hao", 5>}
        command.
        """
        locator = locatorClass()
        responderCallable = locator.locateResponder(b"simple")
        result = responderCallable(amp.Box(greeting=b"ni hao", cookie=b"5"))

        def done(values):
            self.assertEqual(values, amp.AmpBox(cookieplus=b"%d" % (expected,)))

        return result.addCallback(done)

    def test_responderDecorator(self):
        """
        A method on a L{CommandLocator} subclass decorated with a L{Command}
        subclass's L{responder} decorator should be returned from
        locateResponder, wrapped in logic to serialize and deserialize its
        arguments.
        """
        return self._checkSimpleGreeting(TestLocator, 8)

    def test_responderOverriding(self):
        """
        L{CommandLocator} subclasses can override a responder inherited from
        a base class by using the L{Command.responder} decorator to register
        a new responder method.
        """
        return self._checkSimpleGreeting(OverridingLocator, 9)

    def test_responderInheritance(self):
        """
        Responder lookup follows the same rules as normal method lookup
        rules, particularly with respect to inheritance.
        """
        return self._checkSimpleGreeting(InheritingLocator, 9)

    def test_lookupFunctionDeprecatedOverride(self):
        """
        Subclasses which override locateResponder under its old name,
        lookupFunction, should have the override invoked instead.  (This tests
        an AMP subclass, because in the version of the code that could invoke
        this deprecated code path, there was no L{CommandLocator}.)
        """
        locator = OverrideLocatorAMP()
        customResponderObject = self.assertWarns(
            PendingDeprecationWarning,
            "Override locateResponder, not lookupFunction.",
            __file__,
            lambda: locator.locateResponder(b"custom"),
        )
        self.assertEqual(locator.customResponder, customResponderObject)
        # Make sure upcalling works too
        normalResponderObject = self.assertWarns(
            PendingDeprecationWarning,
            "Override locateResponder, not lookupFunction.",
            __file__,
            lambda: locator.locateResponder(b"simple"),
        )
        result = normalResponderObject(amp.Box(greeting=b"ni hao", cookie=b"5"))

        def done(values):
            self.assertEqual(values, amp.AmpBox(cookieplus=b"8"))

        return result.addCallback(done)

    def test_lookupFunctionDeprecatedInvoke(self):
        """
        Invoking locateResponder under its old name, lookupFunction, should
        emit a deprecation warning, but do the same thing.
        """
        locator = TestLocator()
        responderCallable = self.assertWarns(
            PendingDeprecationWarning,
            "Call locateResponder, not lookupFunction.",
            __file__,
            lambda: locator.lookupFunction(b"simple"),
        )
        result = responderCallable(amp.Box(greeting=b"ni hao", cookie=b"5"))

        def done(values):
            self.assertEqual(values, amp.AmpBox(cookieplus=b"8"))

        return result.addCallback(done)


SWITCH_CLIENT_DATA = b"Success!"
SWITCH_SERVER_DATA = b"No, really.  Success."


class BinaryProtocolTests(TestCase):
    """
    Tests for L{amp.BinaryBoxProtocol}.

    @ivar _boxSender: After C{startReceivingBoxes} is called, the L{IBoxSender}
        which was passed to it.
    """

    def setUp(self):
        """
        Keep track of all boxes received by this test in its capacity as an
        L{IBoxReceiver} implementor.
        """
        self.boxes = []
        self.data = []

    def startReceivingBoxes(self, sender):
        """
        Implement L{IBoxReceiver.startReceivingBoxes} to just remember the
        value passed in.
        """
        self._boxSender = sender

    def ampBoxReceived(self, box):
        """
        A box was received by the protocol.
        """
        self.boxes.append(box)

    stopReason = None

    def stopReceivingBoxes(self, reason):
        """
        Record the reason that we stopped receiving boxes.
        """
        self.stopReason = reason

    # fake ITransport
    def getPeer(self):
        return "no peer"

    def getHost(self):
        return "no host"

    def write(self, data):
        self.assertIsInstance(data, bytes)
        self.data.append(data)

    def test_startReceivingBoxes(self):
        """
        When L{amp.BinaryBoxProtocol} is connected to a transport, it calls
        C{startReceivingBoxes} on its L{IBoxReceiver} with itself as the
        L{IBoxSender} parameter.
        """
        protocol = amp.BinaryBoxProtocol(self)
        protocol.makeConnection(None)
        self.assertIs(self._boxSender, protocol)

    def test_sendBoxInStartReceivingBoxes(self):
        """
        The L{IBoxReceiver} which is started when L{amp.BinaryBoxProtocol} is
        connected to a transport can call C{sendBox} on the L{IBoxSender}
        passed to it before C{startReceivingBoxes} returns and have that box
        sent.
        """

        class SynchronouslySendingReceiver:
            def startReceivingBoxes(self, sender):
                sender.sendBox(amp.Box({b"foo": b"bar"}))

        transport = StringTransport()
        protocol = amp.BinaryBoxProtocol(SynchronouslySendingReceiver())
        protocol.makeConnection(transport)
        self.assertEqual(transport.value(), b"\x00\x03foo\x00\x03bar\x00\x00")

    def test_receiveBoxStateMachine(self):
        """
        When a binary box protocol receives:
            * a key
            * a value
            * an empty string
        it should emit a box and send it to its boxReceiver.
        """
        a = amp.BinaryBoxProtocol(self)
        a.stringReceived(b"hello")
        a.stringReceived(b"world")
        a.stringReceived(b"")
        self.assertEqual(self.boxes, [amp.AmpBox(hello=b"world")])

    def test_firstBoxFirstKeyExcessiveLength(self):
        """
        L{amp.BinaryBoxProtocol} drops its connection if the length prefix for
        the first a key it receives is larger than 255.
        """
        transport = StringTransport()
        protocol = amp.BinaryBoxProtocol(self)
        protocol.makeConnection(transport)
        protocol.dataReceived(b"\x01\x00")
        self.assertTrue(transport.disconnecting)

    def test_firstBoxSubsequentKeyExcessiveLength(self):
        """
        L{amp.BinaryBoxProtocol} drops its connection if the length prefix for
        a subsequent key in the first box it receives is larger than 255.
        """
        transport = StringTransport()
        protocol = amp.BinaryBoxProtocol(self)
        protocol.makeConnection(transport)
        protocol.dataReceived(b"\x00\x01k\x00\x01v")
        self.assertFalse(transport.disconnecting)
        protocol.dataReceived(b"\x01\x00")
        self.assertTrue(transport.disconnecting)

    def test_subsequentBoxFirstKeyExcessiveLength(self):
        """
        L{amp.BinaryBoxProtocol} drops its connection if the length prefix for
        the first key in a subsequent box it receives is larger than 255.
        """
        transport = StringTransport()
        protocol = amp.BinaryBoxProtocol(self)
        protocol.makeConnection(transport)
        protocol.dataReceived(b"\x00\x01k\x00\x01v\x00\x00")
        self.assertFalse(transport.disconnecting)
        protocol.dataReceived(b"\x01\x00")
        self.assertTrue(transport.disconnecting)

    def test_excessiveKeyFailure(self):
        """
        If L{amp.BinaryBoxProtocol} disconnects because it received a key
        length prefix which was too large, the L{IBoxReceiver}'s
        C{stopReceivingBoxes} method is called with a L{TooLong} failure.
        """
        protocol = amp.BinaryBoxProtocol(self)
        protocol.makeConnection(StringTransport())
        protocol.dataReceived(b"\x01\x00")
        protocol.connectionLost(
            Failure(error.ConnectionDone("simulated connection done"))
        )
        self.stopReason.trap(amp.TooLong)
        self.assertTrue(self.stopReason.value.isKey)
        self.assertFalse(self.stopReason.value.isLocal)
        self.assertIsNone(self.stopReason.value.value)
        self.assertIsNone(self.stopReason.value.keyName)

    def test_unhandledErrorWithTransport(self):
        """
        L{amp.BinaryBoxProtocol.unhandledError} logs the failure passed to it
        and disconnects its transport.
        """
        transport = StringTransport()
        protocol = amp.BinaryBoxProtocol(self)
        protocol.makeConnection(transport)
        protocol.unhandledError(Failure(RuntimeError("Fake error")))
        self.assertEqual(1, len(self.flushLoggedErrors(RuntimeError)))
        self.assertTrue(transport.disconnecting)

    def test_unhandledErrorWithoutTransport(self):
        """
        L{amp.BinaryBoxProtocol.unhandledError} completes without error when
        there is no associated transport.
        """
        protocol = amp.BinaryBoxProtocol(self)
        protocol.makeConnection(StringTransport())
        protocol.connectionLost(Failure(Exception("Simulated")))
        protocol.unhandledError(Failure(RuntimeError("Fake error")))
        self.assertEqual(1, len(self.flushLoggedErrors(RuntimeError)))

    def test_receiveBoxData(self):
        """
        When a binary box protocol receives the serialized form of an AMP box,
        it should emit a similar box to its boxReceiver.
        """
        a = amp.BinaryBoxProtocol(self)
        a.dataReceived(
            amp.Box(
                {b"testKey": b"valueTest", b"anotherKey": b"anotherValue"}
            ).serialize()
        )
        self.assertEqual(
            self.boxes,
            [amp.Box({b"testKey": b"valueTest", b"anotherKey": b"anotherValue"})],
        )

    def test_receiveLongerBoxData(self):
        """
        An L{amp.BinaryBoxProtocol} can receive serialized AMP boxes with
        values of up to (2 ** 16 - 1) bytes.
        """
        length = 2 ** 16 - 1
        value = b"x" * length
        transport = StringTransport()
        protocol = amp.BinaryBoxProtocol(self)
        protocol.makeConnection(transport)
        protocol.dataReceived(amp.Box({"k": value}).serialize())
        self.assertEqual(self.boxes, [amp.Box({"k": value})])
        self.assertFalse(transport.disconnecting)

    def test_sendBox(self):
        """
        When a binary box protocol sends a box, it should emit the serialized
        bytes of that box to its transport.
        """
        a = amp.BinaryBoxProtocol(self)
        a.makeConnection(self)
        aBox = amp.Box({b"testKey": b"valueTest", b"someData": b"hello"})
        a.makeConnection(self)
        a.sendBox(aBox)
        self.assertEqual(b"".join(self.data), aBox.serialize())

    def test_connectionLostStopSendingBoxes(self):
        """
        When a binary box protocol loses its connection, it should notify its
        box receiver that it has stopped receiving boxes.
        """
        a = amp.BinaryBoxProtocol(self)
        a.makeConnection(self)
        connectionFailure = Failure(RuntimeError())
        a.connectionLost(connectionFailure)
        self.assertIs(self.stopReason, connectionFailure)

    def test_protocolSwitch(self):
        """
        L{BinaryBoxProtocol} has the capacity to switch to a different protocol
        on a box boundary.  When a protocol is in the process of switching, it
        cannot receive traffic.
        """
        otherProto = TestProto(None, b"outgoing data")
        test = self

        class SwitchyReceiver:
            switched = False

            def startReceivingBoxes(self, sender):
                pass

            def ampBoxReceived(self, box):
                test.assertFalse(self.switched, "Should only receive one box!")
                self.switched = True
                a._lockForSwitch()
                a._switchTo(otherProto)

        a = amp.BinaryBoxProtocol(SwitchyReceiver())
        anyOldBox = amp.Box({b"include": b"lots", b"of": b"data"})
        a.makeConnection(self)
        # Include a 0-length box at the beginning of the next protocol's data,
        # to make sure that AMP doesn't eat the data or try to deliver extra
        # boxes either...
        moreThanOneBox = anyOldBox.serialize() + b"\x00\x00Hello, world!"
        a.dataReceived(moreThanOneBox)
        self.assertIs(otherProto.transport, self)
        self.assertEqual(b"".join(otherProto.data), b"\x00\x00Hello, world!")
        self.assertEqual(self.data, [b"outgoing data"])
        a.dataReceived(b"more data")
        self.assertEqual(b"".join(otherProto.data), b"\x00\x00Hello, world!more data")
        self.assertRaises(amp.ProtocolSwitched, a.sendBox, anyOldBox)

    def test_protocolSwitchEmptyBuffer(self):
        """
        After switching to a different protocol, if no extra bytes beyond
        the switch box were delivered, an empty string is not passed to the
        switched protocol's C{dataReceived} method.
        """
        a = amp.BinaryBoxProtocol(self)
        a.makeConnection(self)
        otherProto = TestProto(None, b"")
        a._switchTo(otherProto)
        self.assertEqual(otherProto.data, [])

    def test_protocolSwitchInvalidStates(self):
        """
        In order to make sure the protocol never gets any invalid data sent
        into the middle of a box, it must be locked for switching before it is
        switched.  It can only be unlocked if the switch failed, and attempting
        to send a box while it is locked should raise an exception.
        """
        a = amp.BinaryBoxProtocol(self)
        a.makeConnection(self)
        sampleBox = amp.Box({b"some": b"data"})
        a._lockForSwitch()
        self.assertRaises(amp.ProtocolSwitched, a.sendBox, sampleBox)
        a._unlockFromSwitch()
        a.sendBox(sampleBox)
        self.assertEqual(b"".join(self.data), sampleBox.serialize())
        a._lockForSwitch()
        otherProto = TestProto(None, b"outgoing data")
        a._switchTo(otherProto)
        self.assertRaises(amp.ProtocolSwitched, a._unlockFromSwitch)

    def test_protocolSwitchLoseConnection(self):
        """
        When the protocol is switched, it should notify its nested protocol of
        disconnection.
        """

        class Loser(protocol.Protocol):
            reason = None

            def connectionLost(self, reason):
                self.reason = reason

        connectionLoser = Loser()
        a = amp.BinaryBoxProtocol(self)
        a.makeConnection(self)
        a._lockForSwitch()
        a._switchTo(connectionLoser)
        connectionFailure = Failure(RuntimeError())
        a.connectionLost(connectionFailure)
        self.assertEqual(connectionLoser.reason, connectionFailure)

    def test_protocolSwitchLoseClientConnection(self):
        """
        When the protocol is switched, it should notify its nested client
        protocol factory of disconnection.
        """

        class ClientLoser:
            reason = None

            def clientConnectionLost(self, connector, reason):
                self.reason = reason

        a = amp.BinaryBoxProtocol(self)
        connectionLoser = protocol.Protocol()
        clientLoser = ClientLoser()
        a.makeConnection(self)
        a._lockForSwitch()
        a._switchTo(connectionLoser, clientLoser)
        connectionFailure = Failure(RuntimeError())
        a.connectionLost(connectionFailure)
        self.assertEqual(clientLoser.reason, connectionFailure)


class AMPTests(TestCase):
    def test_interfaceDeclarations(self):
        """
        The classes in the amp module ought to implement the interfaces that
        are declared for their benefit.
        """
        for interface, implementation in [
            (amp.IBoxSender, amp.BinaryBoxProtocol),
            (amp.IBoxReceiver, amp.BoxDispatcher),
            (amp.IResponderLocator, amp.CommandLocator),
            (amp.IResponderLocator, amp.SimpleStringLocator),
            (amp.IBoxSender, amp.AMP),
            (amp.IBoxReceiver, amp.AMP),
            (amp.IResponderLocator, amp.AMP),
        ]:
            self.assertTrue(
                interface.implementedBy(implementation),
                f"{implementation} does not implements({interface})",
            )

    def test_helloWorld(self):
        """
        Verify that a simple command can be sent and its response received with
        the simple low-level string-based API.
        """
        c, s, p = connectedServerAndClient()
        L = []
        HELLO = b"world"
        c.sendHello(HELLO).addCallback(L.append)
        p.flush()
        self.assertEqual(L[0][b"hello"], HELLO)

    def test_wireFormatRoundTrip(self):
        """
        Verify that mixed-case, underscored and dashed arguments are mapped to
        their python names properly.
        """
        c, s, p = connectedServerAndClient()
        L = []
        HELLO = b"world"
        c.sendHello(HELLO).addCallback(L.append)
        p.flush()
        self.assertEqual(L[0][b"hello"], HELLO)

    def test_helloWorldUnicode(self):
        """
        Verify that unicode arguments can be encoded and decoded.
        """
        c, s, p = connectedServerAndClient(
            ServerClass=SimpleSymmetricCommandProtocol,
            ClientClass=SimpleSymmetricCommandProtocol,
        )
        L = []
        HELLO = b"world"
        HELLO_UNICODE = "wor\u1234ld"
        c.sendUnicodeHello(HELLO, HELLO_UNICODE).addCallback(L.append)
        p.flush()
        self.assertEqual(L[0]["hello"], HELLO)
        self.assertEqual(L[0]["Print"], HELLO_UNICODE)

    def test_callRemoteStringRequiresAnswerFalse(self):
        """
        L{BoxDispatcher.callRemoteString} returns L{None} if C{requiresAnswer}
        is C{False}.
        """
        c, s, p = connectedServerAndClient()
        ret = c.callRemoteString(b"WTF", requiresAnswer=False)
        self.assertIsNone(ret)

    def test_unknownCommandLow(self):
        """
        Verify that unknown commands using low-level APIs will be rejected with an
        error, but will NOT terminate the connection.
        """
        c, s, p = connectedServerAndClient()
        L = []

        def clearAndAdd(e):
            """
            You can't propagate the error...
            """
            e.trap(amp.UnhandledCommand)
            return "OK"

        c.callRemoteString(b"WTF").addErrback(clearAndAdd).addCallback(L.append)
        p.flush()
        self.assertEqual(L.pop(), "OK")
        HELLO = b"world"
        c.sendHello(HELLO).addCallback(L.append)
        p.flush()
        self.assertEqual(L[0][b"hello"], HELLO)

    def test_unknownCommandHigh(self):
        """
        Verify that unknown commands using high-level APIs will be rejected with an
        error, but will NOT terminate the connection.
        """
        c, s, p = connectedServerAndClient()
        L = []

        def clearAndAdd(e):
            """
            You can't propagate the error...
            """
            e.trap(amp.UnhandledCommand)
            return "OK"

        c.callRemote(WTF).addErrback(clearAndAdd).addCallback(L.append)
        p.flush()
        self.assertEqual(L.pop(), "OK")
        HELLO = b"world"
        c.sendHello(HELLO).addCallback(L.append)
        p.flush()
        self.assertEqual(L[0][b"hello"], HELLO)

    def test_brokenReturnValue(self):
        """
        It can be very confusing if you write some code which responds to a
        command, but gets the return value wrong.  Most commonly you end up
        returning None instead of a dictionary.

        Verify that if that happens, the framework logs a useful error.
        """
        L = []
        SimpleSymmetricCommandProtocol().dispatchCommand(
            amp.AmpBox(_command=BrokenReturn.commandName)
        ).addErrback(L.append)
        L[0].trap(amp.BadLocalReturn)
        self.failUnlessIn("None", repr(L[0].value))

    def test_unknownArgument(self):
        """
        Verify that unknown arguments are ignored, and not passed to a Python
        function which can't accept them.
        """
        c, s, p = connectedServerAndClient(
            ServerClass=SimpleSymmetricCommandProtocol,
            ClientClass=SimpleSymmetricCommandProtocol,
        )
        L = []
        HELLO = b"world"
        # c.sendHello(HELLO).addCallback(L.append)
        c.callRemote(
            FutureHello, hello=HELLO, bonus=b"I'm not in the book!"
        ).addCallback(L.append)
        p.flush()
        self.assertEqual(L[0]["hello"], HELLO)

    def test_simpleReprs(self):
        """
        Verify that the various Box objects repr properly, for debugging.
        """
        self.assertEqual(type(repr(amp._SwitchBox("a"))), str)
        self.assertEqual(type(repr(amp.QuitBox())), str)
        self.assertEqual(type(repr(amp.AmpBox())), str)
        self.assertIn("AmpBox", repr(amp.AmpBox()))

    def test_innerProtocolInRepr(self):
        """
        Verify that L{AMP} objects output their innerProtocol when set.
        """
        otherProto = TestProto(None, b"outgoing data")
        a = amp.AMP()
        a.innerProtocol = otherProto

        self.assertEqual(
            repr(a),
            "<AMP inner <TestProto #%d> at 0x%x>" % (otherProto.instanceId, id(a)),
        )

    def test_innerProtocolNotInRepr(self):
        """
        Verify that L{AMP} objects do not output 'inner' when no innerProtocol
        is set.
        """
        a = amp.AMP()
        self.assertEqual(repr(a), f"<AMP at 0x{id(a):x}>")

    @skipIf(skipSSL, "SSL not available")
    def test_simpleSSLRepr(self):
        """
        L{amp._TLSBox.__repr__} returns a string.
        """
        self.assertEqual(type(repr(amp._TLSBox())), str)

    def test_keyTooLong(self):
        """
        Verify that a key that is too long will immediately raise a synchronous
        exception.
        """
        c, s, p = connectedServerAndClient()
        x = "H" * (0xFF + 1)
        tl = self.assertRaises(amp.TooLong, c.callRemoteString, b"Hello", **{x: b"hi"})
        self.assertTrue(tl.isKey)
        self.assertTrue(tl.isLocal)
        self.assertIsNone(tl.keyName)
        self.assertEqual(tl.value, x.encode("ascii"))
        self.assertIn(str(len(x)), repr(tl))
        self.assertIn("key", repr(tl))

    def test_valueTooLong(self):
        """
        Verify that attempting to send value longer than 64k will immediately
        raise an exception.
        """
        c, s, p = connectedServerAndClient()
        x = b"H" * (0xFFFF + 1)
        tl = self.assertRaises(amp.TooLong, c.sendHello, x)
        p.flush()
        self.assertFalse(tl.isKey)
        self.assertTrue(tl.isLocal)
        self.assertEqual(tl.keyName, b"hello")
        self.failUnlessIdentical(tl.value, x)
        self.assertIn(str(len(x)), repr(tl))
        self.assertIn("value", repr(tl))
        self.assertIn("hello", repr(tl))

    def test_helloWorldCommand(self):
        """
        Verify that a simple command can be sent and its response received with
        the high-level value parsing API.
        """
        c, s, p = connectedServerAndClient(
            ServerClass=SimpleSymmetricCommandProtocol,
            ClientClass=SimpleSymmetricCommandProtocol,
        )
        L = []
        HELLO = b"world"
        c.sendHello(HELLO).addCallback(L.append)
        p.flush()
        self.assertEqual(L[0]["hello"], HELLO)

    def test_helloErrorHandling(self):
        """
        Verify that if a known error type is raised and handled, it will be
        properly relayed to the other end of the connection and translated into
        an exception, and no error will be logged.
        """
        L = []
        c, s, p = connectedServerAndClient(
            ServerClass=SimpleSymmetricCommandProtocol,
            ClientClass=SimpleSymmetricCommandProtocol,
        )
        HELLO = b"fuck you"
        c.sendHello(HELLO).addErrback(L.append)
        p.flush()
        L[0].trap(UnfriendlyGreeting)
        self.assertEqual(str(L[0].value), "Don't be a dick.")

    def test_helloFatalErrorHandling(self):
        """
        Verify that if a known, fatal error type is raised and handled, it will
        be properly relayed to the other end of the connection and translated
        into an exception, no error will be logged, and the connection will be
        terminated.
        """
        L = []
        c, s, p = connectedServerAndClient(
            ServerClass=SimpleSymmetricCommandProtocol,
            ClientClass=SimpleSymmetricCommandProtocol,
        )
        HELLO = b"die"
        c.sendHello(HELLO).addErrback(L.append)
        p.flush()
        L.pop().trap(DeathThreat)
        c.sendHello(HELLO).addErrback(L.append)
        p.flush()
        L.pop().trap(error.ConnectionDone)

    def test_helloNoErrorHandling(self):
        """
        Verify that if an unknown error type is raised, it will be relayed to
        the other end of the connection and translated into an exception, it
        will be logged, and then the connection will be dropped.
        """
        L = []
        c, s, p = connectedServerAndClient(
            ServerClass=SimpleSymmetricCommandProtocol,
            ClientClass=SimpleSymmetricCommandProtocol,
        )
        HELLO = THING_I_DONT_UNDERSTAND
        c.sendHello(HELLO).addErrback(L.append)
        p.flush()
        ure = L.pop()
        ure.trap(amp.UnknownRemoteError)
        c.sendHello(HELLO).addErrback(L.append)
        cl = L.pop()
        cl.trap(error.ConnectionDone)
        # The exception should have been logged.
        self.assertTrue(self.flushLoggedErrors(ThingIDontUnderstandError))

    def test_lateAnswer(self):
        """
        Verify that a command that does not get answered until after the
        connection terminates will not cause any errors.
        """
        c, s, p = connectedServerAndClient(
            ServerClass=SimpleSymmetricCommandProtocol,
            ClientClass=SimpleSymmetricCommandProtocol,
        )
        L = []
        c.callRemote(WaitForever).addErrback(L.append)
        p.flush()
        self.assertEqual(L, [])
        s.transport.loseConnection()
        p.flush()
        L.pop().trap(error.ConnectionDone)
        # Just make sure that it doesn't error...
        s.waiting.callback({})
        return s.waiting

    def test_requiresNoAnswer(self):
        """
        Verify that a command that requires no answer is run.
        """
        c, s, p = connectedServerAndClient(
            ServerClass=SimpleSymmetricCommandProtocol,
            ClientClass=SimpleSymmetricCommandProtocol,
        )
        HELLO = b"world"
        c.callRemote(NoAnswerHello, hello=HELLO)
        p.flush()
        self.assertTrue(s.greeted)

    def test_requiresNoAnswerFail(self):
        """
        Verify that commands sent after a failed no-answer request do not complete.
        """
        L = []
        c, s, p = connectedServerAndClient(
            ServerClass=SimpleSymmetricCommandProtocol,
            ClientClass=SimpleSymmetricCommandProtocol,
        )
        HELLO = b"fuck you"
        c.callRemote(NoAnswerHello, hello=HELLO)
        p.flush()
        # This should be logged locally.
        self.assertTrue(self.flushLoggedErrors(amp.RemoteAmpError))
        HELLO = b"world"
        c.callRemote(Hello, hello=HELLO).addErrback(L.append)
        p.flush()
        L.pop().trap(error.ConnectionDone)
        self.assertFalse(s.greeted)

    def test_requiresNoAnswerAfterFail(self):
        """
        No-answer commands sent after the connection has been torn down do not
        return a L{Deferred}.
        """
        c, s, p = connectedServerAndClient(
            ServerClass=SimpleSymmetricCommandProtocol,
            ClientClass=SimpleSymmetricCommandProtocol,
        )
        c.transport.loseConnection()
        p.flush()
        result = c.callRemote(NoAnswerHello, hello=b"ignored")
        self.assertIs(result, None)

    def test_noAnswerResponderBadAnswer(self):
        """
        Verify that responders of requiresAnswer=False commands have to return
        a dictionary anyway.

        (requiresAnswer is a hint from the _client_ - the server may be called
        upon to answer commands in any case, if the client wants to know when
        they complete.)
        """
        c, s, p = connectedServerAndClient(
            ServerClass=BadNoAnswerCommandProtocol,
            ClientClass=SimpleSymmetricCommandProtocol,
        )
        c.callRemote(NoAnswerHello, hello=b"hello")
        p.flush()
        le = self.flushLoggedErrors(amp.BadLocalReturn)
        self.assertEqual(len(le), 1)

    def test_noAnswerResponderAskedForAnswer(self):
        """
        Verify that responders with requiresAnswer=False will actually respond
        if the client sets requiresAnswer=True.  In other words, verify that
        requiresAnswer is a hint honored only by the client.
        """
        c, s, p = connectedServerAndClient(
            ServerClass=NoAnswerCommandProtocol,
            ClientClass=SimpleSymmetricCommandProtocol,
        )
        L = []
        c.callRemote(Hello, hello=b"Hello!").addCallback(L.append)
        p.flush()
        self.assertEqual(len(L), 1)
        self.assertEqual(
            L, [dict(hello=b"Hello!-noanswer", Print=None)]
        )  # Optional response argument

    def test_ampListCommand(self):
        """
        Test encoding of an argument that uses the AmpList encoding.
        """
        c, s, p = connectedServerAndClient(
            ServerClass=SimpleSymmetricCommandProtocol,
            ClientClass=SimpleSymmetricCommandProtocol,
        )
        L = []
        c.callRemote(GetList, length=10).addCallback(L.append)
        p.flush()
        values = L.pop().get("body")
        self.assertEqual(values, [{"x": 1}] * 10)

    def test_optionalAmpListOmitted(self):
        """
        Sending a command with an omitted AmpList argument that is
        designated as optional does not raise an InvalidSignature error.
        """
        c, s, p = connectedServerAndClient(
            ServerClass=SimpleSymmetricCommandProtocol,
            ClientClass=SimpleSymmetricCommandProtocol,
        )
        L = []
        c.callRemote(DontRejectMe, magicWord="please").addCallback(L.append)
        p.flush()
        response = L.pop().get("response")
        self.assertEqual(response, "list omitted")

    def test_optionalAmpListPresent(self):
        """
        Sanity check that optional AmpList arguments are processed normally.
        """
        c, s, p = connectedServerAndClient(
            ServerClass=SimpleSymmetricCommandProtocol,
            ClientClass=SimpleSymmetricCommandProtocol,
        )
        L = []
        c.callRemote(
            DontRejectMe, magicWord="please", list=[{"name": "foo"}]
        ).addCallback(L.append)
        p.flush()
        response = L.pop().get("response")
        self.assertEqual(response, "foo accepted")

    def test_failEarlyOnArgSending(self):
        """
        Verify that if we pass an invalid argument list (omitting an argument),
        an exception will be raised.
        """
        self.assertRaises(amp.InvalidSignature, Hello)

    def test_doubleProtocolSwitch(self):
        """
        As a debugging aid, a protocol system should raise a
        L{ProtocolSwitched} exception when asked to switch a protocol that is
        already switched.
        """
        serverDeferred = defer.Deferred()
        serverProto = SimpleSymmetricCommandProtocol(serverDeferred)
        clientDeferred = defer.Deferred()
        clientProto = SimpleSymmetricCommandProtocol(clientDeferred)
        c, s, p = connectedServerAndClient(
            ServerClass=lambda: serverProto, ClientClass=lambda: clientProto
        )

        def switched(result):
            self.assertRaises(amp.ProtocolSwitched, c.switchToTestProtocol)
            self.testSucceeded = True

        c.switchToTestProtocol().addCallback(switched)
        p.flush()
        self.assertTrue(self.testSucceeded)

    def test_protocolSwitch(
        self,
        switcher=SimpleSymmetricCommandProtocol,
        spuriousTraffic=False,
        spuriousError=False,
    ):
        """
        Verify that it is possible to switch to another protocol mid-connection and
        send data to it successfully.
        """
        self.testSucceeded = False

        serverDeferred = defer.Deferred()
        serverProto = switcher(serverDeferred)
        clientDeferred = defer.Deferred()
        clientProto = switcher(clientDeferred)
        c, s, p = connectedServerAndClient(
            ServerClass=lambda: serverProto, ClientClass=lambda: clientProto
        )

        if spuriousTraffic:
            wfdr = []  # remote
            c.callRemote(WaitForever).addErrback(wfdr.append)
        switchDeferred = c.switchToTestProtocol()
        if spuriousTraffic:
            self.assertRaises(amp.ProtocolSwitched, c.sendHello, b"world")

        def cbConnsLost(info):
            ((serverSuccess, serverData), (clientSuccess, clientData)) = info
            self.assertTrue(serverSuccess)
            self.assertTrue(clientSuccess)
            self.assertEqual(b"".join(serverData), SWITCH_CLIENT_DATA)
            self.assertEqual(b"".join(clientData), SWITCH_SERVER_DATA)
            self.testSucceeded = True

        def cbSwitch(proto):
            return defer.DeferredList([serverDeferred, clientDeferred]).addCallback(
                cbConnsLost
            )

        switchDeferred.addCallback(cbSwitch)
        p.flush()
        if serverProto.maybeLater is not None:
            serverProto.maybeLater.callback(serverProto.maybeLaterProto)
            p.flush()
        if spuriousTraffic:
            # switch is done here; do this here to make sure that if we're
            # going to corrupt the connection, we do it before it's closed.
            if spuriousError:
                s.waiting.errback(
                    amp.RemoteAmpError(
                        b"SPURIOUS", "Here's some traffic in the form of an error."
                    )
                )
            else:
                s.waiting.callback({})
            p.flush()
        c.transport.loseConnection()  # close it
        p.flush()
        self.assertTrue(self.testSucceeded)

    def test_protocolSwitchDeferred(self):
        """
        Verify that protocol-switching even works if the value returned from
        the command that does the switch is deferred.
        """
        return self.test_protocolSwitch(switcher=DeferredSymmetricCommandProtocol)

    def test_protocolSwitchFail(self, switcher=SimpleSymmetricCommandProtocol):
        """
        Verify that if we try to switch protocols and it fails, the connection
        stays up and we can go back to speaking AMP.
        """
        self.testSucceeded = False

        serverDeferred = defer.Deferred()
        serverProto = switcher(serverDeferred)
        clientDeferred = defer.Deferred()
        clientProto = switcher(clientDeferred)
        c, s, p = connectedServerAndClient(
            ServerClass=lambda: serverProto, ClientClass=lambda: clientProto
        )
        L = []
        c.switchToTestProtocol(fail=True).addErrback(L.append)
        p.flush()
        L.pop().trap(UnknownProtocol)
        self.assertFalse(self.testSucceeded)
        # It's a known error, so let's send a "hello" on the same connection;
        # it should work.
        c.sendHello(b"world").addCallback(L.append)
        p.flush()
        self.assertEqual(L.pop()["hello"], b"world")

    def test_trafficAfterSwitch(self):
        """
        Verify that attempts to send traffic after a switch will not corrupt
        the nested protocol.
        """
        return self.test_protocolSwitch(spuriousTraffic=True)

    def test_errorAfterSwitch(self):
        """
        Returning an error after a protocol switch should record the underlying
        error.
        """
        return self.test_protocolSwitch(spuriousTraffic=True, spuriousError=True)

    def test_quitBoxQuits(self):
        """
        Verify that commands with a responseType of QuitBox will in fact
        terminate the connection.
        """
        c, s, p = connectedServerAndClient(
            ServerClass=SimpleSymmetricCommandProtocol,
            ClientClass=SimpleSymmetricCommandProtocol,
        )

        L = []
        HELLO = b"world"
        GOODBYE = b"everyone"
        c.sendHello(HELLO).addCallback(L.append)
        p.flush()
        self.assertEqual(L.pop()["hello"], HELLO)
        c.callRemote(Goodbye).addCallback(L.append)
        p.flush()
        self.assertEqual(L.pop()["goodbye"], GOODBYE)
        c.sendHello(HELLO).addErrback(L.append)
        L.pop().trap(error.ConnectionDone)

    def test_basicLiteralEmit(self):
        """
        Verify that the command dictionaries for a callRemoteN look correct
        after being serialized and parsed.
        """
        c, s, p = connectedServerAndClient()
        L = []
        s.ampBoxReceived = L.append
        c.callRemote(
            Hello,
            hello=b"hello test",
            mixedCase=b"mixed case arg test",
            dash_arg=b"x",
            underscore_arg=b"y",
        )
        p.flush()
        self.assertEqual(len(L), 1)
        for k, v in [
            (b"_command", Hello.commandName),
            (b"hello", b"hello test"),
            (b"mixedCase", b"mixed case arg test"),
            (b"dash-arg", b"x"),
            (b"underscore_arg", b"y"),
        ]:
            self.assertEqual(L[-1].pop(k), v)
        L[-1].pop(b"_ask")
        self.assertEqual(L[-1], {})

    def test_basicStructuredEmit(self):
        """
        Verify that a call similar to basicLiteralEmit's is handled properly with
        high-level quoting and passing to Python methods, and that argument
        names are correctly handled.
        """
        L = []

        class StructuredHello(amp.AMP):
            def h(self, *a, **k):
                L.append((a, k))
                return dict(hello=b"aaa")

            Hello.responder(h)

        c, s, p = connectedServerAndClient(ServerClass=StructuredHello)
        c.callRemote(
            Hello,
            hello=b"hello test",
            mixedCase=b"mixed case arg test",
            dash_arg=b"x",
            underscore_arg=b"y",
        ).addCallback(L.append)
        p.flush()
        self.assertEqual(len(L), 2)
        self.assertEqual(
            L[0],
            (
                (),
                dict(
                    hello=b"hello test",
                    mixedCase=b"mixed case arg test",
                    dash_arg=b"x",
                    underscore_arg=b"y",
                    From=s.transport.getPeer(),
                    # XXX - should optional arguments just not be passed?
                    # passing None seems a little odd, looking at the way it
                    # turns out here... -glyph
                    Print=None,
                    optional=None,
                ),
            ),
        )
        self.assertEqual(L[1], dict(Print=None, hello=b"aaa"))


class PretendRemoteCertificateAuthority:
    def checkIsPretendRemote(self):
        return True


class IOSimCert:
    verifyCount = 0

    def options(self, *ign):
        return self

    def iosimVerify(self, otherCert):
        """
        This isn't a real certificate, and wouldn't work on a real socket, but
        iosim specifies a different API so that we don't have to do any crypto
        math to demonstrate that the right functions get called in the right
        places.
        """
        assert otherCert is self
        self.verifyCount += 1
        return True


class OKCert(IOSimCert):
    def options(self, x):
        assert x.checkIsPretendRemote()
        return self


class GrumpyCert(IOSimCert):
    def iosimVerify(self, otherCert):
        self.verifyCount += 1
        return False


class DroppyCert(IOSimCert):
    def __init__(self, toDrop):
        self.toDrop = toDrop

    def iosimVerify(self, otherCert):
        self.verifyCount += 1
        self.toDrop.loseConnection()
        return True


class SecurableProto(FactoryNotifier):

    factory = None

    def verifyFactory(self):
        return [PretendRemoteCertificateAuthority()]

    def getTLSVars(self):
        cert = self.certFactory()
        verify = self.verifyFactory()
        return dict(tls_localCertificate=cert, tls_verifyAuthorities=verify)

    amp.StartTLS.responder(getTLSVars)


@skipIf(skipSSL, "SSL not available")
@skipIf(reactorLacksSSL, "This test case requires SSL support in the reactor")
class TLSTests(TestCase):
    def test_startingTLS(self):
        """
        Verify that starting TLS and succeeding at handshaking sends all the
        notifications to all the right places.
        """
        cli, svr, p = connectedServerAndClient(
            ServerClass=SecurableProto, ClientClass=SecurableProto
        )

        okc = OKCert()
        svr.certFactory = lambda: okc

        cli.callRemote(
            amp.StartTLS,
            tls_localCertificate=okc,
            tls_verifyAuthorities=[PretendRemoteCertificateAuthority()],
        )

        # let's buffer something to be delivered securely
        L = []
        cli.callRemote(SecuredPing).addCallback(L.append)
        p.flush()
        # once for client once for server
        self.assertEqual(okc.verifyCount, 2)
        L = []
        cli.callRemote(SecuredPing).addCallback(L.append)
        p.flush()
        self.assertEqual(L[0], {"pinged": True})

    def test_startTooManyTimes(self):
        """
        Verify that the protocol will complain if we attempt to renegotiate TLS,
        which we don't support.
        """
        cli, svr, p = connectedServerAndClient(
            ServerClass=SecurableProto, ClientClass=SecurableProto
        )

        okc = OKCert()
        svr.certFactory = lambda: okc

        cli.callRemote(
            amp.StartTLS,
            tls_localCertificate=okc,
            tls_verifyAuthorities=[PretendRemoteCertificateAuthority()],
        )
        p.flush()
        cli.noPeerCertificate = True  # this is totally fake
        self.assertRaises(
            amp.OnlyOneTLS,
            cli.callRemote,
            amp.StartTLS,
            tls_localCertificate=okc,
            tls_verifyAuthorities=[PretendRemoteCertificateAuthority()],
        )

    def test_negotiationFailed(self):
        """
        Verify that starting TLS and failing on both sides at handshaking sends
        notifications to all the right places and terminates the connection.
        """

        badCert = GrumpyCert()

        cli, svr, p = connectedServerAndClient(
            ServerClass=SecurableProto, ClientClass=SecurableProto
        )
        svr.certFactory = lambda: badCert

        cli.callRemote(amp.StartTLS, tls_localCertificate=badCert)

        p.flush()
        # once for client once for server - but both fail
        self.assertEqual(badCert.verifyCount, 2)
        d = cli.callRemote(SecuredPing)
        p.flush()
        self.assertFailure(d, iosim.NativeOpenSSLError)

    def test_negotiationFailedByClosing(self):
        """
        Verify that starting TLS and failing by way of a lost connection
        notices that it is probably an SSL problem.
        """

        cli, svr, p = connectedServerAndClient(
            ServerClass=SecurableProto, ClientClass=SecurableProto
        )
        droppyCert = DroppyCert(svr.transport)
        svr.certFactory = lambda: droppyCert

        cli.callRemote(amp.StartTLS, tls_localCertificate=droppyCert)

        p.flush()

        self.assertEqual(droppyCert.verifyCount, 2)

        d = cli.callRemote(SecuredPing)
        p.flush()

        # it might be a good idea to move this exception somewhere more
        # reasonable.
        self.assertFailure(d, error.PeerVerifyError)


class TLSNotAvailableTests(TestCase):
    """
    Tests what happened when ssl is not available in current installation.
    """

    def setUp(self):
        """
        Disable ssl in amp.
        """
        self.ssl = amp.ssl
        amp.ssl = None

    def tearDown(self):
        """
        Restore ssl module.
        """
        amp.ssl = self.ssl

    def test_callRemoteError(self):
        """
        Check that callRemote raises an exception when called with a
        L{amp.StartTLS}.
        """
        cli, svr, p = connectedServerAndClient(
            ServerClass=SecurableProto, ClientClass=SecurableProto
        )

        okc = OKCert()
        svr.certFactory = lambda: okc

        return self.assertFailure(
            cli.callRemote(
                amp.StartTLS,
                tls_localCertificate=okc,
                tls_verifyAuthorities=[PretendRemoteCertificateAuthority()],
            ),
            RuntimeError,
        )

    def test_messageReceivedError(self):
        """
        When a client with SSL enabled talks to a server without SSL, it
        should return a meaningful error.
        """
        svr = SecurableProto()
        okc = OKCert()
        svr.certFactory = lambda: okc
        box = amp.Box()
        box[b"_command"] = b"StartTLS"
        box[b"_ask"] = b"1"
        boxes = []
        svr.sendBox = boxes.append
        svr.makeConnection(StringTransport())
        svr.ampBoxReceived(box)
        self.assertEqual(
            boxes,
            [
                {
                    b"_error_code": b"TLS_ERROR",
                    b"_error": b"1",
                    b"_error_description": b"TLS not available",
                }
            ],
        )


class InheritedError(Exception):
    """
    This error is used to check inheritance.
    """


class OtherInheritedError(Exception):
    """
    This is a distinct error for checking inheritance.
    """


class BaseCommand(amp.Command):
    """
    This provides a command that will be subclassed.
    """

    errors: Dict[Type[Exception], bytes] = {InheritedError: b"INHERITED_ERROR"}


class InheritedCommand(BaseCommand):
    """
    This is a command which subclasses another command but does not override
    anything.
    """


class AddErrorsCommand(BaseCommand):
    """
    This is a command which subclasses another command but adds errors to the
    list.
    """

    arguments = [(b"other", amp.Boolean())]
    errors: Dict[Type[Exception], bytes] = {
        OtherInheritedError: b"OTHER_INHERITED_ERROR"
    }


class NormalCommandProtocol(amp.AMP):
    """
    This is a protocol which responds to L{BaseCommand}, and is used to test
    that inheritance does not interfere with the normal handling of errors.
    """

    def resp(self):
        raise InheritedError()

    BaseCommand.responder(resp)


class InheritedCommandProtocol(amp.AMP):
    """
    This is a protocol which responds to L{InheritedCommand}, and is used to
    test that inherited commands inherit their bases' errors if they do not
    respond to any of their own.
    """

    def resp(self):
        raise InheritedError()

    InheritedCommand.responder(resp)


class AddedCommandProtocol(amp.AMP):
    """
    This is a protocol which responds to L{AddErrorsCommand}, and is used to
    test that inherited commands can add their own new types of errors, but
    still respond in the same way to their parents types of errors.
    """

    def resp(self, other):
        if other:
            raise OtherInheritedError()
        else:
            raise InheritedError()

    AddErrorsCommand.responder(resp)


class CommandInheritanceTests(TestCase):
    """
    These tests verify that commands inherit error conditions properly.
    """

    def errorCheck(self, err, proto, cmd, **kw):
        """
        Check that the appropriate kind of error is raised when a given command
        is sent to a given protocol.
        """
        c, s, p = connectedServerAndClient(ServerClass=proto, ClientClass=proto)
        d = c.callRemote(cmd, **kw)
        d2 = self.failUnlessFailure(d, err)
        p.flush()
        return d2

    def test_basicErrorPropagation(self):
        """
        Verify that errors specified in a superclass are respected normally
        even if it has subclasses.
        """
        return self.errorCheck(InheritedError, NormalCommandProtocol, BaseCommand)

    def test_inheritedErrorPropagation(self):
        """
        Verify that errors specified in a superclass command are propagated to
        its subclasses.
        """
        return self.errorCheck(
            InheritedError, InheritedCommandProtocol, InheritedCommand
        )

    def test_inheritedErrorAddition(self):
        """
        Verify that new errors specified in a subclass of an existing command
        are honored even if the superclass defines some errors.
        """
        return self.errorCheck(
            OtherInheritedError, AddedCommandProtocol, AddErrorsCommand, other=True
        )

    def test_additionWithOriginalError(self):
        """
        Verify that errors specified in a command's superclass are respected
        even if that command defines new errors itself.
        """
        return self.errorCheck(
            InheritedError, AddedCommandProtocol, AddErrorsCommand, other=False
        )


def _loseAndPass(err, proto):
    # be specific, pass on the error to the client.
    err.trap(error.ConnectionLost, error.ConnectionDone)
    del proto.connectionLost
    proto.connectionLost(err)


class LiveFireBase:
    """
    Utility for connected reactor-using tests.
    """

    def setUp(self):
        """
        Create an amp server and connect a client to it.
        """
        from twisted.internet import reactor

        self.serverFactory = protocol.ServerFactory()
        self.serverFactory.protocol = self.serverProto
        self.clientFactory = protocol.ClientFactory()
        self.clientFactory.protocol = self.clientProto
        self.clientFactory.onMade = defer.Deferred()
        self.serverFactory.onMade = defer.Deferred()
        self.serverPort = reactor.listenTCP(0, self.serverFactory)
        self.addCleanup(self.serverPort.stopListening)
        self.clientConn = reactor.connectTCP(
            "127.0.0.1", self.serverPort.getHost().port, self.clientFactory
        )
        self.addCleanup(self.clientConn.disconnect)

        def getProtos(rlst):
            self.cli = self.clientFactory.theProto
            self.svr = self.serverFactory.theProto

        dl = defer.DeferredList([self.clientFactory.onMade, self.serverFactory.onMade])
        return dl.addCallback(getProtos)

    def tearDown(self):
        """
        Cleanup client and server connections, and check the error got at
        C{connectionLost}.
        """
        L = []
        for conn in self.cli, self.svr:
            if conn.transport is not None:
                # depend on amp's function connection-dropping behavior
                d = defer.Deferred().addErrback(_loseAndPass, conn)
                conn.connectionLost = d.errback
                conn.transport.loseConnection()
                L.append(d)
        return defer.gatherResults(L).addErrback(lambda first: first.value.subFailure)


def show(x):
    import sys

    sys.stdout.write(x + "\n")
    sys.stdout.flush()


def tempSelfSigned():
    from twisted.internet import ssl

    sharedDN = ssl.DN(CN="shared")
    key = ssl.KeyPair.generate()
    cr = key.certificateRequest(sharedDN)
    sscrd = key.signCertificateRequest(sharedDN, cr, lambda dn: True, 1234567)
    cert = key.newCertificate(sscrd)
    return cert


if ssl is not None:
    tempcert = tempSelfSigned()


@skipIf(skipSSL, "SSL not available")
@skipIf(reactorLacksSSL, "This test case requires SSL support in the reactor")
class LiveFireTLSTests(LiveFireBase, TestCase):

    clientProto = SecurableProto
    serverProto = SecurableProto

    def test_liveFireCustomTLS(self):
        """
        Using real, live TLS, actually negotiate a connection.

        This also looks at the 'peerCertificate' attribute's correctness, since
        that's actually loaded using OpenSSL calls, but the main purpose is to
        make sure that we didn't miss anything obvious in iosim about TLS
        negotiations.
        """

        cert = tempcert

        self.svr.verifyFactory = lambda: [cert]
        self.svr.certFactory = lambda: cert
        # only needed on the server, we specify the client below.

        def secured(rslt):
            x = cert.digest()

            def pinged(rslt2):
                # Interesting.  OpenSSL won't even _tell_ us about the peer
                # cert until we negotiate.  we should be able to do this in
                # 'secured' instead, but it looks like we can't.  I think this
                # is a bug somewhere far deeper than here.
                self.assertEqual(x, self.cli.hostCertificate.digest())
                self.assertEqual(x, self.cli.peerCertificate.digest())
                self.assertEqual(x, self.svr.hostCertificate.digest())
                self.assertEqual(x, self.svr.peerCertificate.digest())

            return self.cli.callRemote(SecuredPing).addCallback(pinged)

        return self.cli.callRemote(
            amp.StartTLS, tls_localCertificate=cert, tls_verifyAuthorities=[cert]
        ).addCallback(secured)


class SlightlySmartTLS(SimpleSymmetricCommandProtocol):
    """
    Specific implementation of server side protocol with different
    management of TLS.
    """

    def getTLSVars(self):
        """
        @return: the global C{tempcert} certificate as local certificate.
        """
        return dict(tls_localCertificate=tempcert)

    amp.StartTLS.responder(getTLSVars)


@skipIf(skipSSL, "SSL not available")
@skipIf(reactorLacksSSL, "This test case requires SSL support in the reactor")
class PlainVanillaLiveFireTests(LiveFireBase, TestCase):

    clientProto = SimpleSymmetricCommandProtocol
    serverProto = SimpleSymmetricCommandProtocol

    def test_liveFireDefaultTLS(self):
        """
        Verify that out of the box, we can start TLS to at least encrypt the
        connection, even if we don't have any certificates to use.
        """

        def secured(result):
            return self.cli.callRemote(SecuredPing)

        return self.cli.callRemote(amp.StartTLS).addCallback(secured)


@skipIf(skipSSL, "SSL not available")
@skipIf(reactorLacksSSL, "This test case requires SSL support in the reactor")
class WithServerTLSVerificationTests(LiveFireBase, TestCase):

    clientProto = SimpleSymmetricCommandProtocol
    serverProto = SlightlySmartTLS

    def test_anonymousVerifyingClient(self):
        """
        Verify that anonymous clients can verify server certificates.
        """

        def secured(result):
            return self.cli.callRemote(SecuredPing)

        return self.cli.callRemote(
            amp.StartTLS, tls_verifyAuthorities=[tempcert]
        ).addCallback(secured)


class ProtocolIncludingArgument(amp.Argument):
    """
    An L{amp.Argument} which encodes its parser and serializer
    arguments *including the protocol* into its parsed and serialized
    forms.
    """

    def fromStringProto(self, string, protocol):
        """
        Don't decode anything; just return all possible information.

        @return: A two-tuple of the input string and the protocol.
        """
        return (string, protocol)

    def toStringProto(self, obj, protocol):
        """
        Encode identifying information about L{object} and protocol
        into a string for later verification.

        @type obj: L{object}
        @type protocol: L{amp.AMP}
        """
        ident = "%d:%d" % (id(obj), id(protocol))
        return ident.encode("ascii")


class ProtocolIncludingCommand(amp.Command):
    """
    A command that has argument and response schemas which use
    L{ProtocolIncludingArgument}.
    """

    arguments = [(b"weird", ProtocolIncludingArgument())]
    response = [(b"weird", ProtocolIncludingArgument())]


class MagicSchemaCommand(amp.Command):
    """
    A command which overrides L{parseResponse}, L{parseArguments}, and
    L{makeResponse}.
    """

    @classmethod
    def parseResponse(self, strings, protocol):
        """
        Don't do any parsing, just jam the input strings and protocol
        onto the C{protocol.parseResponseArguments} attribute as a
        two-tuple. Return the original strings.
        """
        protocol.parseResponseArguments = (strings, protocol)
        return strings

    @classmethod
    def parseArguments(cls, strings, protocol):
        """
        Don't do any parsing, just jam the input strings and protocol
        onto the C{protocol.parseArgumentsArguments} attribute as a
        two-tuple. Return the original strings.
        """
        protocol.parseArgumentsArguments = (strings, protocol)
        return strings

    @classmethod
    def makeArguments(cls, objects, protocol):
        """
        Don't do any serializing, just jam the input strings and protocol
        onto the C{protocol.makeArgumentsArguments} attribute as a
        two-tuple. Return the original strings.
        """
        protocol.makeArgumentsArguments = (objects, protocol)
        return objects


class NoNetworkProtocol(amp.AMP):
    """
    An L{amp.AMP} subclass which overrides private methods to avoid
    testing the network. It also provides a responder for
    L{MagicSchemaCommand} that does nothing, so that tests can test
    aspects of the interaction of L{amp.Command}s and L{amp.AMP}.

    @ivar parseArgumentsArguments: Arguments that have been passed to any
    L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by
    this protocol.

    @ivar parseResponseArguments: Responses that have been returned from a
    L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by
    this protocol.

    @ivar makeArgumentsArguments: Arguments that have been serialized by any
    L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by
    this protocol.
    """

    def _sendBoxCommand(self, commandName, strings, requiresAnswer):
        """
        Return a Deferred which fires with the original strings.
        """
        return defer.succeed(strings)

    MagicSchemaCommand.responder(lambda s, weird: {})


class MyBox(dict):
    """
    A unique dict subclass.
    """


class ProtocolIncludingCommandWithDifferentCommandType(ProtocolIncludingCommand):
    """
    A L{ProtocolIncludingCommand} subclass whose commandType is L{MyBox}
    """

    commandType = MyBox  # type: ignore[assignment]


class CommandTests(TestCase):
    """
    Tests for L{amp.Argument} and L{amp.Command}.
    """

    def test_argumentInterface(self):
        """
        L{Argument} instances provide L{amp.IArgumentType}.
        """
        self.assertTrue(verifyObject(amp.IArgumentType, amp.Argument()))

    def test_parseResponse(self):
        """
        There should be a class method of Command which accepts a
        mapping of argument names to serialized forms and returns a
        similar mapping whose values have been parsed via the
        Command's response schema.
        """
        protocol = object()
        result = b"whatever"
        strings = {b"weird": result}
        self.assertEqual(
            ProtocolIncludingCommand.parseResponse(strings, protocol),
            {"weird": (result, protocol)},
        )

    def test_callRemoteCallsParseResponse(self):
        """
        Making a remote call on a L{amp.Command} subclass which
        overrides the C{parseResponse} method should call that
        C{parseResponse} method to get the response.
        """
        client = NoNetworkProtocol()
        thingy = b"weeoo"
        response = client.callRemote(MagicSchemaCommand, weird=thingy)

        def gotResponse(ign):
            self.assertEqual(client.parseResponseArguments, ({"weird": thingy}, client))

        response.addCallback(gotResponse)
        return response

    def test_parseArguments(self):
        """
        There should be a class method of L{amp.Command} which accepts
        a mapping of argument names to serialized forms and returns a
        similar mapping whose values have been parsed via the
        command's argument schema.
        """
        protocol = object()
        result = b"whatever"
        strings = {b"weird": result}
        self.assertEqual(
            ProtocolIncludingCommand.parseArguments(strings, protocol),
            {"weird": (result, protocol)},
        )

    def test_responderCallsParseArguments(self):
        """
        Making a remote call on a L{amp.Command} subclass which
        overrides the C{parseArguments} method should call that
        C{parseArguments} method to get the arguments.
        """
        protocol = NoNetworkProtocol()
        responder = protocol.locateResponder(MagicSchemaCommand.commandName)
        argument = object()
        response = responder(dict(weird=argument))
        response.addCallback(
            lambda ign: self.assertEqual(
                protocol.parseArgumentsArguments, ({"weird": argument}, protocol)
            )
        )
        return response

    def test_makeArguments(self):
        """
        There should be a class method of L{amp.Command} which accepts
        a mapping of argument names to objects and returns a similar
        mapping whose values have been serialized via the command's
        argument schema.
        """
        protocol = object()
        argument = object()
        objects = {"weird": argument}
        ident = "%d:%d" % (id(argument), id(protocol))
        self.assertEqual(
            ProtocolIncludingCommand.makeArguments(objects, protocol),
            {b"weird": ident.encode("ascii")},
        )

    def test_makeArgumentsUsesCommandType(self):
        """
        L{amp.Command.makeArguments}'s return type should be the type
        of the result of L{amp.Command.commandType}.
        """
        protocol = object()
        objects = {"weird": b"whatever"}

        result = ProtocolIncludingCommandWithDifferentCommandType.makeArguments(
            objects, protocol
        )
        self.assertIs(type(result), MyBox)

    def test_callRemoteCallsMakeArguments(self):
        """
        Making a remote call on a L{amp.Command} subclass which
        overrides the C{makeArguments} method should call that
        C{makeArguments} method to get the response.
        """
        client = NoNetworkProtocol()
        argument = object()
        response = client.callRemote(MagicSchemaCommand, weird=argument)

        def gotResponse(ign):
            self.assertEqual(
                client.makeArgumentsArguments, ({"weird": argument}, client)
            )

        response.addCallback(gotResponse)
        return response

    def test_extraArgumentsDisallowed(self):
        """
        L{Command.makeArguments} raises L{amp.InvalidSignature} if the objects
        dictionary passed to it includes a key which does not correspond to the
        Python identifier for a defined argument.
        """
        self.assertRaises(
            amp.InvalidSignature,
            Hello.makeArguments,
            dict(hello="hello", bogusArgument=object()),
            None,
        )

    def test_wireSpellingDisallowed(self):
        """
        If a command argument conflicts with a Python keyword, the
        untransformed argument name is not allowed as a key in the dictionary
        passed to L{Command.makeArguments}.  If it is supplied,
        L{amp.InvalidSignature} is raised.

        This may be a pointless implementation restriction which may be lifted.
        The current behavior is tested to verify that such arguments are not
        silently dropped on the floor (the previous behavior).
        """
        self.assertRaises(
            amp.InvalidSignature,
            Hello.makeArguments,
            dict(hello="required", **{"print": "print value"}),
            None,
        )

    def test_commandNameDefaultsToClassNameAsByteString(self):
        """
        A L{Command} subclass without a defined C{commandName} that's
        not a byte string.
        """

        class NewCommand(amp.Command):
            """
            A new command.
            """

        self.assertEqual(b"NewCommand", NewCommand.commandName)

    def test_commandNameMustBeAByteString(self):
        """
        A L{Command} subclass cannot be defined with a C{commandName} that's
        not a byte string.
        """
        error = self.assertRaises(
            TypeError, type, "NewCommand", (amp.Command,), {"commandName": "FOO"}
        )
        self.assertRegex(
            str(error), "^Command names must be byte strings, got: u?'FOO'$"
        )

    def test_commandArgumentsMustBeNamedWithByteStrings(self):
        """
        A L{Command} subclass's C{arguments} must have byte string names.
        """
        error = self.assertRaises(
            TypeError,
            type,
            "NewCommand",
            (amp.Command,),
            {"arguments": [("foo", None)]},
        )
        self.assertRegex(
            str(error), "^Argument names must be byte strings, got: u?'foo'$"
        )

    def test_commandResponseMustBeNamedWithByteStrings(self):
        """
        A L{Command} subclass's C{response} must have byte string names.
        """
        error = self.assertRaises(
            TypeError, type, "NewCommand", (amp.Command,), {"response": [("foo", None)]}
        )
        self.assertRegex(
            str(error), "^Response names must be byte strings, got: u?'foo'$"
        )

    def test_commandErrorsIsConvertedToDict(self):
        """
        A L{Command} subclass's C{errors} is coerced into a C{dict}.
        """

        class NewCommand(amp.Command):
            errors = [(ZeroDivisionError, b"ZDE")]

        self.assertEqual({ZeroDivisionError: b"ZDE"}, NewCommand.errors)

    def test_commandErrorsMustUseBytesForOnWireRepresentation(self):
        """
        A L{Command} subclass's C{errors} must map exceptions to byte strings.
        """
        error = self.assertRaises(
            TypeError,
            type,
            "NewCommand",
            (amp.Command,),
            {"errors": [(ZeroDivisionError, "foo")]},
        )
        self.assertRegex(str(error), "^Error names must be byte strings, got: u?'foo'$")

    def test_commandFatalErrorsIsConvertedToDict(self):
        """
        A L{Command} subclass's C{fatalErrors} is coerced into a C{dict}.
        """

        class NewCommand(amp.Command):
            fatalErrors = [(ZeroDivisionError, b"ZDE")]

        self.assertEqual({ZeroDivisionError: b"ZDE"}, NewCommand.fatalErrors)

    def test_commandFatalErrorsMustUseBytesForOnWireRepresentation(self):
        """
        A L{Command} subclass's C{fatalErrors} must map exceptions to byte
        strings.
        """
        error = self.assertRaises(
            TypeError,
            type,
            "NewCommand",
            (amp.Command,),
            {"fatalErrors": [(ZeroDivisionError, "foo")]},
        )
        self.assertRegex(
            str(error), "^Fatal error names must be byte strings, " "got: u?'foo'$"
        )


class ListOfTestsMixin:
    """
    Base class for testing L{ListOf}, a parameterized zero-or-more argument
    type.

    @ivar elementType: Subclasses should set this to an L{Argument}
        instance.  The tests will make a L{ListOf} using this.

    @ivar strings: Subclasses should set this to a dictionary mapping some
        number of keys -- as BYTE strings -- to the correct serialized form
        for some example values. These should agree with what L{elementType}
        produces/accepts.

    @ivar objects: Subclasses should set this to a dictionary with the same
        keys as C{strings} -- as NATIVE strings -- and with values which are
        the lists which should serialize to the values in the C{strings}
        dictionary.
    """

    def test_toBox(self):
        """
        L{ListOf.toBox} extracts the list of objects from the C{objects}
        dictionary passed to it, using the C{name} key also passed to it,
        serializes each of the elements in that list using the L{Argument}
        instance previously passed to its initializer, combines the serialized
        results, and inserts the result into the C{strings} dictionary using
        the same C{name} key.
        """
        stringList = amp.ListOf(self.elementType)
        strings = amp.AmpBox()
        for key in self.objects:
            stringList.toBox(key.encode("ascii"), strings, self.objects.copy(), None)
        self.assertEqual(strings, self.strings)

    def test_fromBox(self):
        """
        L{ListOf.fromBox} reverses the operation performed by L{ListOf.toBox}.
        """
        stringList = amp.ListOf(self.elementType)
        objects = {}
        for key in self.strings:
            stringList.fromBox(key, self.strings.copy(), objects, None)
        self.assertEqual(objects, self.objects)


class ListOfStringsTests(TestCase, ListOfTestsMixin):
    """
    Tests for L{ListOf} combined with L{amp.String}.
    """

    elementType = amp.String()

    strings = {
        b"empty": b"",
        b"single": b"\x00\x03foo",
        b"multiple": b"\x00\x03bar\x00\x03baz\x00\x04quux",
    }

    objects = {"empty": [], "single": [b"foo"], "multiple": [b"bar", b"baz", b"quux"]}


class ListOfIntegersTests(TestCase, ListOfTestsMixin):
    """
    Tests for L{ListOf} combined with L{amp.Integer}.
    """

    elementType = amp.Integer()

    huge = (
        9999999999999999999999999999999999999999999999999999999999
        * 9999999999999999999999999999999999999999999999999999999999
    )

    strings = {
        b"empty": b"",
        b"single": b"\x00\x0210",
        b"multiple": b"\x00\x011\x00\x0220\x00\x03500",
        b"huge": b"\x00\x74%d" % (huge,),
        b"negative": b"\x00\x02-1",
    }

    objects = {
        "empty": [],
        "single": [10],
        "multiple": [1, 20, 500],
        "huge": [huge],
        "negative": [-1],
    }


class ListOfUnicodeTests(TestCase, ListOfTestsMixin):
    """
    Tests for L{ListOf} combined with L{amp.Unicode}.
    """

    elementType = amp.Unicode()

    strings = {
        b"empty": b"",
        b"single": b"\x00\x03foo",
        b"multiple": b"\x00\x03\xe2\x98\x83\x00\x05Hello\x00\x05world",
    }

    objects = {
        "empty": [],
        "single": ["foo"],
        "multiple": ["\N{SNOWMAN}", "Hello", "world"],
    }


class ListOfDecimalTests(TestCase, ListOfTestsMixin):
    """
    Tests for L{ListOf} combined with L{amp.Decimal}.
    """

    elementType = amp.Decimal()

    strings = {
        b"empty": b"",
        b"single": b"\x00\x031.1",
        b"extreme": b"\x00\x08Infinity\x00\x09-Infinity",
        b"scientist": b"\x00\x083.141E+5\x00\x0a0.00003141\x00\x083.141E-7"
        b"\x00\x09-3.141E+5\x00\x0b-0.00003141\x00\x09-3.141E-7",
        b"engineer": (
            b"\x00\x04"
            + decimal.Decimal("0e6").to_eng_string().encode("ascii")
            + b"\x00\x06"
            + decimal.Decimal("1.5E-9").to_eng_string().encode("ascii")
        ),
    }

    objects = {
        "empty": [],
        "single": [decimal.Decimal("1.1")],
        "extreme": [
            decimal.Decimal("Infinity"),
            decimal.Decimal("-Infinity"),
        ],
        # exarkun objected to AMP supporting engineering notation because
        # it was redundant, until we realised that 1E6 has less precision
        # than 1000000 and is represented differently.  But they compare
        # and even hash equally.  There were tears.
        "scientist": [
            decimal.Decimal("3.141E5"),
            decimal.Decimal("3.141e-5"),
            decimal.Decimal("3.141E-7"),
            decimal.Decimal("-3.141e5"),
            decimal.Decimal("-3.141E-5"),
            decimal.Decimal("-3.141e-7"),
        ],
        "engineer": [
            decimal.Decimal("0e6"),
            decimal.Decimal("1.5E-9"),
        ],
    }


class ListOfDecimalNanTests(TestCase, ListOfTestsMixin):
    """
    Tests for L{ListOf} combined with L{amp.Decimal} for not-a-number values.
    """

    elementType = amp.Decimal()

    strings = {
        b"nan": b"\x00\x03NaN\x00\x04-NaN\x00\x04sNaN\x00\x05-sNaN",
    }

    objects = {
        "nan": [
            decimal.Decimal("NaN"),
            decimal.Decimal("-NaN"),
            decimal.Decimal("sNaN"),
            decimal.Decimal("-sNaN"),
        ]
    }

    def test_fromBox(self):
        """
        L{ListOf.fromBox} reverses the operation performed by L{ListOf.toBox}.
        """
        # Helpers.  Decimal.is_{qnan,snan,signed}() are new in 2.6 (or 2.5.2,
        # but who's counting).
        def is_qnan(decimal):
            return "NaN" in str(decimal) and "sNaN" not in str(decimal)

        def is_snan(decimal):
            return "sNaN" in str(decimal)

        def is_signed(decimal):
            return "-" in str(decimal)

        # NaN values have unusual equality semantics, so this method is
        # overridden to compare the resulting objects in a way which works with
        # NaNs.
        stringList = amp.ListOf(self.elementType)
        objects = {}
        for key in self.strings:
            stringList.fromBox(key, self.strings.copy(), objects, None)
        n = objects["nan"]
        self.assertTrue(is_qnan(n[0]) and not is_signed(n[0]))
        self.assertTrue(is_qnan(n[1]) and is_signed(n[1]))
        self.assertTrue(is_snan(n[2]) and not is_signed(n[2]))
        self.assertTrue(is_snan(n[3]) and is_signed(n[3]))


class DecimalTests(TestCase):
    """
    Tests for L{amp.Decimal}.
    """

    def test_nonDecimal(self):
        """
        L{amp.Decimal.toString} raises L{ValueError} if passed an object which
        is not an instance of C{decimal.Decimal}.
        """
        argument = amp.Decimal()
        self.assertRaises(ValueError, argument.toString, "1.234")
        self.assertRaises(ValueError, argument.toString, 1.234)
        self.assertRaises(ValueError, argument.toString, 1234)


class FloatTests(TestCase):
    """
    Tests for L{amp.Float}.
    """

    def test_nonFloat(self):
        """
        L{amp.Float.toString} raises L{ValueError} if passed an object which
        is not a L{float}.
        """
        argument = amp.Float()
        self.assertRaises(ValueError, argument.toString, "1.234")
        self.assertRaises(ValueError, argument.toString, b"1.234")
        self.assertRaises(ValueError, argument.toString, 1234)

    def test_float(self):
        """
        L{amp.Float.toString} returns a bytestring when it is given a L{float}.
        """
        argument = amp.Float()
        self.assertEqual(argument.toString(1.234), b"1.234")


class ListOfDateTimeTests(TestCase, ListOfTestsMixin):
    """
    Tests for L{ListOf} combined with L{amp.DateTime}.
    """

    elementType = amp.DateTime()

    strings = {
        b"christmas": b"\x00\x202010-12-25T00:00:00.000000-00:00"
        b"\x00\x202010-12-25T00:00:00.000000-00:00",
        b"christmas in eu": b"\x00\x202010-12-25T00:00:00.000000+01:00",
        b"christmas in iran": b"\x00\x202010-12-25T00:00:00.000000+03:30",
        b"christmas in nyc": b"\x00\x202010-12-25T00:00:00.000000-05:00",
        b"previous tests": b"\x00\x202010-12-25T00:00:00.000000+03:19"
        b"\x00\x202010-12-25T00:00:00.000000-06:59",
    }

    objects = {
        "christmas": [
            datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=amp.utc),
            datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("+", 0, 0)),
        ],
        "christmas in eu": [
            datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("+", 1, 0)),
        ],
        "christmas in iran": [
            datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("+", 3, 30)),
        ],
        "christmas in nyc": [
            datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("-", 5, 0)),
        ],
        "previous tests": [
            datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("+", 3, 19)),
            datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("-", 6, 59)),
        ],
    }


class ListOfOptionalTests(TestCase):
    """
    Tests to ensure L{ListOf} AMP arguments can be omitted from AMP commands
    via the 'optional' flag.
    """

    def test_requiredArgumentWithNoneValueRaisesTypeError(self):
        """
        L{ListOf.toBox} raises C{TypeError} when passed a value of L{None}
        for the argument.
        """
        stringList = amp.ListOf(amp.Integer())
        self.assertRaises(
            TypeError,
            stringList.toBox,
            b"omitted",
            amp.AmpBox(),
            {"omitted": None},
            None,
        )

    def test_optionalArgumentWithNoneValueOmitted(self):
        """
        L{ListOf.toBox} silently omits serializing any argument with a
        value of L{None} that is designated as optional for the protocol.
        """
        stringList = amp.ListOf(amp.Integer(), optional=True)
        strings = amp.AmpBox()
        stringList.toBox(b"omitted", strings, {b"omitted": None}, None)
        self.assertEqual(strings, {})

    def test_requiredArgumentWithKeyMissingRaisesKeyError(self):
        """
        L{ListOf.toBox} raises C{KeyError} if the argument's key is not
        present in the objects dictionary.
        """
        stringList = amp.ListOf(amp.Integer())
        self.assertRaises(
            KeyError,
            stringList.toBox,
            b"ommited",
            amp.AmpBox(),
            {"someOtherKey": 0},
            None,
        )

    def test_optionalArgumentWithKeyMissingOmitted(self):
        """
        L{ListOf.toBox} silently omits serializing any argument designated
        as optional whose key is not present in the objects dictionary.
        """
        stringList = amp.ListOf(amp.Integer(), optional=True)
        stringList.toBox(b"ommited", amp.AmpBox(), {b"someOtherKey": 0}, None)

    def test_omittedOptionalArgumentDeserializesAsNone(self):
        """
        L{ListOf.fromBox} correctly reverses the operation performed by
        L{ListOf.toBox} for optional arguments.
        """
        stringList = amp.ListOf(amp.Integer(), optional=True)
        objects = {}
        stringList.fromBox(b"omitted", {}, objects, None)
        self.assertEqual(objects, {"omitted": None})


@implementer(interfaces.IUNIXTransport)
class UNIXStringTransport:
    """
    An in-memory implementation of L{interfaces.IUNIXTransport} which collects
    all data given to it for later inspection.

    @ivar _queue: A C{list} of the data which has been given to this transport,
        eg via C{write} or C{sendFileDescriptor}.  Elements are two-tuples of a
        string (identifying the destination of the data) and the data itself.
    """

    def __init__(self, descriptorFuzz):
        """
        @param descriptorFuzz: An offset to apply to descriptors.
        @type descriptorFuzz: C{int}
        """
        self._fuzz = descriptorFuzz
        self._queue = []

    def sendFileDescriptor(self, descriptor):
        self._queue.append(("fileDescriptorReceived", descriptor + self._fuzz))

    def write(self, data):
        self._queue.append(("dataReceived", data))

    def writeSequence(self, seq):
        for data in seq:
            self.write(data)

    def loseConnection(self):
        self._queue.append(("connectionLost", Failure(error.ConnectionLost())))

    def getHost(self):
        return address.UNIXAddress("/tmp/some-path")

    def getPeer(self):
        return address.UNIXAddress("/tmp/another-path")


# Minimal evidence that we got the signatures right
verifyClass(interfaces.ITransport, UNIXStringTransport)
verifyClass(interfaces.IUNIXTransport, UNIXStringTransport)


class DescriptorTests(TestCase):
    """
    Tests for L{amp.Descriptor}, an argument type for passing a file descriptor
    over an AMP connection over a UNIX domain socket.
    """

    def setUp(self):
        self.fuzz = 3
        self.transport = UNIXStringTransport(descriptorFuzz=self.fuzz)
        self.protocol = amp.BinaryBoxProtocol(amp.BoxDispatcher(amp.CommandLocator()))
        self.protocol.makeConnection(self.transport)

    def test_fromStringProto(self):
        """
        L{Descriptor.fromStringProto} constructs a file descriptor value by
        extracting a previously received file descriptor corresponding to the
        wire value of the argument from the L{_DescriptorExchanger} state of the
        protocol passed to it.

        This is a whitebox test which involves direct L{_DescriptorExchanger}
        state inspection.
        """
        argument = amp.Descriptor()
        self.protocol.fileDescriptorReceived(5)
        self.protocol.fileDescriptorReceived(3)
        self.protocol.fileDescriptorReceived(1)
        self.assertEqual(5, argument.fromStringProto("0", self.protocol))
        self.assertEqual(3, argument.fromStringProto("1", self.protocol))
        self.assertEqual(1, argument.fromStringProto("2", self.protocol))
        self.assertEqual({}, self.protocol._descriptors)

    def test_toStringProto(self):
        """
        To send a file descriptor, L{Descriptor.toStringProto} uses the
        L{IUNIXTransport.sendFileDescriptor} implementation of the transport of
        the protocol passed to it to copy the file descriptor.  Each subsequent
        descriptor sent over a particular AMP connection is assigned the next
        integer value, starting from 0.  The base ten string representation of
        this value is the byte encoding of the argument.

        This is a whitebox test which involves direct L{_DescriptorExchanger}
        state inspection and mutation.
        """
        argument = amp.Descriptor()
        self.assertEqual(b"0", argument.toStringProto(2, self.protocol))
        self.assertEqual(
            ("fileDescriptorReceived", 2 + self.fuzz), self.transport._queue.pop(0)
        )
        self.assertEqual(b"1", argument.toStringProto(4, self.protocol))
        self.assertEqual(
            ("fileDescriptorReceived", 4 + self.fuzz), self.transport._queue.pop(0)
        )
        self.assertEqual(b"2", argument.toStringProto(6, self.protocol))
        self.assertEqual(
            ("fileDescriptorReceived", 6 + self.fuzz), self.transport._queue.pop(0)
        )
        self.assertEqual({}, self.protocol._descriptors)

    def test_roundTrip(self):
        """
        L{amp.Descriptor.fromBox} can interpret an L{amp.AmpBox} constructed by
        L{amp.Descriptor.toBox} to reconstruct a file descriptor value.
        """
        name = "alpha"
        nameAsBytes = name.encode("ascii")
        strings = {}
        descriptor = 17
        sendObjects = {name: descriptor}

        argument = amp.Descriptor()
        argument.toBox(nameAsBytes, strings, sendObjects.copy(), self.protocol)

        receiver = amp.BinaryBoxProtocol(amp.BoxDispatcher(amp.CommandLocator()))
        for event in self.transport._queue:
            getattr(receiver, event[0])(*event[1:])

        receiveObjects = {}
        argument.fromBox(nameAsBytes, strings.copy(), receiveObjects, receiver)

        # Make sure we got the descriptor.  Adjust by fuzz to be more convincing
        # of having gone through L{IUNIXTransport.sendFileDescriptor}, not just
        # converted to a string and then parsed back into an integer.
        self.assertEqual(descriptor + self.fuzz, receiveObjects[name])


class DateTimeTests(TestCase):
    """
    Tests for L{amp.DateTime}, L{amp._FixedOffsetTZInfo}, and L{amp.utc}.
    """

    string = b"9876-01-23T12:34:56.054321-01:23"
    tzinfo = tz("-", 1, 23)
    object = datetime.datetime(9876, 1, 23, 12, 34, 56, 54321, tzinfo)

    def test_invalidString(self):
        """
        L{amp.DateTime.fromString} raises L{ValueError} when passed a string
        which does not represent a timestamp in the proper format.
        """
        d = amp.DateTime()
        self.assertRaises(ValueError, d.fromString, "abc")

    def test_invalidDatetime(self):
        """
        L{amp.DateTime.toString} raises L{ValueError} when passed a naive
        datetime (a datetime with no timezone information).
        """
        d = amp.DateTime()
        self.assertRaises(
            ValueError, d.toString, datetime.datetime(2010, 12, 25, 0, 0, 0)
        )

    def test_fromString(self):
        """
        L{amp.DateTime.fromString} returns a C{datetime.datetime} with all of
        its fields populated from the string passed to it.
        """
        argument = amp.DateTime()
        value = argument.fromString(self.string)
        self.assertEqual(value, self.object)

    def test_toString(self):
        """
        L{amp.DateTime.toString} returns a C{str} in the wire format including
        all of the information from the C{datetime.datetime} passed into it,
        including the timezone offset.
        """
        argument = amp.DateTime()
        value = argument.toString(self.object)
        self.assertEqual(value, self.string)


class UTCTests(TestCase):
    """
    Tests for L{amp.utc}.
    """

    def test_tzname(self):
        """
        L{amp.utc.tzname} returns C{"+00:00"}.
        """
        self.assertEqual(amp.utc.tzname(None), "+00:00")

    def test_dst(self):
        """
        L{amp.utc.dst} returns a zero timedelta.
        """
        self.assertEqual(amp.utc.dst(None), datetime.timedelta(0))

    def test_utcoffset(self):
        """
        L{amp.utc.utcoffset} returns a zero timedelta.
        """
        self.assertEqual(amp.utc.utcoffset(None), datetime.timedelta(0))

    def test_badSign(self):
        """
        L{amp._FixedOffsetTZInfo.fromSignHoursMinutes} raises L{ValueError} if
        passed an offset sign other than C{'+'} or C{'-'}.
        """
        self.assertRaises(ValueError, tz, "?", 0, 0)


class RemoteAmpErrorTests(TestCase):
    """
    Tests for L{amp.RemoteAmpError}.
    """

    def test_stringMessage(self):
        """
        L{amp.RemoteAmpError} renders the given C{errorCode} (C{bytes}) and
        C{description} into a native string.
        """
        error = amp.RemoteAmpError(b"BROKEN", "Something has broken")
        self.assertEqual("Code<BROKEN>: Something has broken", str(error))

    def test_stringMessageReplacesNonAsciiText(self):
        """
        When C{errorCode} contains non-ASCII characters, L{amp.RemoteAmpError}
        renders then as backslash-escape sequences.
        """
        error = amp.RemoteAmpError(b"BROKEN-\xff", "Something has broken")
        self.assertEqual("Code<BROKEN-\\xff>: Something has broken", str(error))

    def test_stringMessageWithLocalFailure(self):
        """
        L{amp.RemoteAmpError} renders local errors with a "(local)" marker and
        a brief traceback.
        """
        failure = Failure(Exception("Something came loose"))
        error = amp.RemoteAmpError(b"BROKEN", "Something has broken", local=failure)
        self.assertRegex(
            str(error),
            (
                "^Code<BROKEN> [(]local[)]: Something has broken\n"
                "Traceback [(]failure with no frames[)]: "
                "<.+Exception.>: Something came loose\n"
            ),
        )

Zerion Mini Shell 1.0