Mini Shell

Directory: /opt/imh-python/lib/python3.9/site-packages/twisted/web/test/
Upload File :
Current File : //opt/imh-python/lib/python3.9/site-packages/twisted/web/test/test_agent.py

# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.

"""
Tests for L{twisted.web.client.Agent} and related new client APIs.
"""

import zlib
from http.cookiejar import CookieJar
from io import BytesIO
from typing import TYPE_CHECKING, List, Optional, Tuple
from unittest import SkipTest, skipIf

from zope.interface.declarations import implementer
from zope.interface.verify import verifyObject

from incremental import Version

from twisted.internet import defer, task
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.defer import CancelledError, Deferred, succeed
from twisted.internet.endpoints import HostnameEndpoint, TCP4ClientEndpoint
from twisted.internet.error import (
    ConnectionDone,
    ConnectionLost,
    ConnectionRefusedError,
)
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
from twisted.internet.protocol import Factory, Protocol
from twisted.internet.task import Clock
from twisted.internet.test.test_endpoints import deterministicResolvingReactor
from twisted.logger import globalLogPublisher
from twisted.python.components import proxyForInterface
from twisted.python.deprecate import getDeprecationWarningString
from twisted.python.failure import Failure
from twisted.test.iosim import FakeTransport, IOPump
from twisted.test.proto_helpers import (
    AccumulatingProtocol,
    EventLoggingObserver,
    MemoryReactorClock,
    StringTransport,
)
from twisted.test.test_sslverify import certificatesForAuthorityAndServer
from twisted.trial.unittest import SynchronousTestCase, TestCase
from twisted.web import client, error, http_headers
from twisted.web._newclient import (
    HTTP11ClientProtocol,
    PotentialDataLoss,
    RequestNotSent,
    RequestTransmissionFailed,
    Response,
    ResponseFailed,
    ResponseNeverReceived,
)
from twisted.web.client import (
    URI,
    BrowserLikePolicyForHTTPS,
    FileBodyProducer,
    HostnameCachingHTTPSPolicy,
    HTTPConnectionPool,
    Request,
    ResponseDone,
    _HTTP11ClientFactory,
)
from twisted.web.error import SchemeNotSupported
from twisted.web.http_headers import Headers
from twisted.web.iweb import (
    UNKNOWN_LENGTH,
    IAgent,
    IAgentEndpointFactory,
    IBodyProducer,
    IPolicyForHTTPS,
    IResponse,
)
from twisted.web.test.injectionhelpers import (
    MethodInjectionTestsMixin,
    URIInjectionTestsMixin,
)

# Creatively lie to mypy about the nature of inheritance, since dealing with
# expectations of a mixin class is basically impossible (don't use mixins).
#
# At runtime ``testMixinClass`` is plain ``object`` and ``runtimeTestCase`` is
# ``TestCase``; under static type checking the two are swapped, so a mixin
# deriving from ``testMixinClass`` type-checks as if it were a TestCase.
# NOTE(review): neither name is used in the visible part of this file;
# presumably mixins further down inherit from them.
if TYPE_CHECKING:
    testMixinClass = TestCase
    runtimeTestCase = object
else:
    testMixinClass = object
    runtimeTestCase = TestCase

# TLS support is optional.  If ``twisted.internet.ssl`` cannot be imported,
# ``ssl`` is None and ``sslPresent`` is False; otherwise the TLS helpers below
# are imported and a trust-root stub is defined.  NOTE(review): presumably
# TLS-dependent tests elsewhere in this file consult ``sslPresent`` to skip.
try:
    from twisted.internet import ssl as _ssl
except ImportError:
    ssl = None
    sslPresent = False
else:
    ssl = _ssl
    sslPresent = True
    from twisted.internet._sslverify import ClientTLSOptions, IOpenSSLTrustRoot
    from twisted.internet.ssl import optionsForClientTLS
    from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol

    @implementer(IOpenSSLTrustRoot)
    class CustomOpenSSLTrustRoot:
        """
        An L{IOpenSSLTrustRoot} that adds no certificates, but records that it
        was asked to and which context it was given.
        """

        # Flipped to True once _addCACertsToContext has been invoked.
        called = False
        # The context passed to the most recent _addCACertsToContext call.
        context = None

        def _addCACertsToContext(self, context):
            """
            Record the call and the given C{context} instead of adding
            any CA certificates to it.
            """
            self.called = True
            self.context = context


class StubHTTPProtocol(Protocol):
    """
    A stand-in for L{HTTP11ClientProtocol} that speaks no HTTP at all; it
    only records the requests it is asked to issue.

    @ivar requests: A C{list} of C{(request, deferred)} two-tuples, one per
        call to L{request}, where C{deferred} is the L{Deferred} that call
        returned.

    @ivar state: Starts out as C{"QUIESCENT"}; tests overwrite it to simulate
        other connection states.
    """

    def __init__(self) -> None:
        self.state = "QUIESCENT"
        self.requests: List[Tuple[Request, Deferred[IResponse]]] = []

    def request(self, request):
        """
        Remember C{request} so a test can inspect it afterwards.

        @return: A new L{Deferred}, also stored next to C{request} in
            C{self.requests}; this class never fires it.
        """
        pending = Deferred()
        self.requests.append((request, pending))
        return pending


class FileConsumer:
    """
    A minimal consumer that forwards everything written to it into a
    file-like object.

    @ivar outputFile: The file-like object that receives the written bytes.
    """

    def __init__(self, outputFile):
        self.outputFile = outputFile

    def write(self, bytes):
        """
        Hand C{bytes} straight to C{self.outputFile.write}.
        """
        sink = self.outputFile
        sink.write(bytes)


class FileBodyProducerTests(TestCase):
    """
    Tests for the L{FileBodyProducer} which reads bytes from a file and writes
    them to an L{IConsumer}.
    """

    def _termination(self):
        """
        This method can be used as the C{terminationPredicateFactory} for a
        L{Cooperator}.  It returns a predicate which immediately returns
        C{True}, indicating that no more work should be done this iteration.
        This has the result of only allowing one iteration of a cooperative
        task to be run per L{Cooperator} iteration.
        """
        return lambda: True

    def setUp(self):
        """
        Create a L{Cooperator} hooked up to an easily controlled, deterministic
        scheduler to use with L{FileBodyProducer}.

        Rather than scheduling work on a reactor, the cooperator appends each
        scheduled callable to C{self._scheduled}.  Tests pop and call those
        by hand to advance production one step at a time.
        """
        self._scheduled = []
        self.cooperator = task.Cooperator(self._termination, self._scheduled.append)

    def test_interface(self):
        """
        L{FileBodyProducer} instances provide L{IBodyProducer}.
        """
        self.assertTrue(verifyObject(IBodyProducer, FileBodyProducer(BytesIO(b""))))

    def test_unknownLength(self):
        """
        If the L{FileBodyProducer} is constructed with a file-like object that
        lacks either a C{seek} or a C{tell} method (having only one of them is
        not enough), its C{length} attribute is set to C{UNKNOWN_LENGTH}.
        """

        class HasSeek:
            def seek(self, offset, whence):
                pass

        class HasTell:
            def tell(self):
                pass

        producer = FileBodyProducer(HasSeek())
        self.assertEqual(UNKNOWN_LENGTH, producer.length)
        producer = FileBodyProducer(HasTell())
        self.assertEqual(UNKNOWN_LENGTH, producer.length)

    def test_knownLength(self):
        """
        If the L{FileBodyProducer} is constructed with a file-like object with
        both C{seek} and C{tell} methods, its C{length} attribute is set to the
        size of the file as determined by those methods.
        """
        inputBytes = b"here are some bytes"
        inputFile = BytesIO(inputBytes)
        inputFile.seek(5)
        producer = FileBodyProducer(inputFile)
        # The length counts from the current position, and that position is
        # left where it was.
        self.assertEqual(len(inputBytes) - 5, producer.length)
        self.assertEqual(inputFile.tell(), 5)

    def test_defaultCooperator(self):
        """
        If no L{Cooperator} instance is passed to L{FileBodyProducer}, the
        global cooperator is used.
        """
        producer = FileBodyProducer(BytesIO(b""))
        self.assertEqual(task.cooperate, producer._cooperate)

    def test_startProducing(self):
        """
        L{FileBodyProducer.startProducing} starts writing bytes from the input
        file to the given L{IConsumer} and returns a L{Deferred} which fires
        when they have all been written.
        """
        expectedResult = b"hello, world"
        readSize = 3
        output = BytesIO()
        consumer = FileConsumer(output)
        producer = FileBodyProducer(BytesIO(expectedResult), self.cooperator, readSize)
        complete = producer.startProducing(consumer)
        # One iteration per full readSize chunk, plus one more.
        for _ in range(len(expectedResult) // readSize + 1):
            self._scheduled.pop(0)()
        self.assertEqual([], self._scheduled)
        self.assertEqual(expectedResult, output.getvalue())
        self.assertIsNone(self.successResultOf(complete))

    def test_inputClosedAtEOF(self):
        """
        When L{FileBodyProducer} reaches end-of-file on the input file given to
        it, the input file is closed.
        """
        readSize = 4
        inputBytes = b"some friendly bytes"
        inputFile = BytesIO(inputBytes)
        producer = FileBodyProducer(inputFile, self.cooperator, readSize)
        consumer = FileConsumer(BytesIO())
        producer.startProducing(consumer)
        for _ in range(len(inputBytes) // readSize + 2):
            self._scheduled.pop(0)()
        self.assertTrue(inputFile.closed)

    def test_failedReadWhileProducing(self):
        """
        If a read from the input file fails while producing bytes to the
        consumer, the L{Deferred} returned by
        L{FileBodyProducer.startProducing} fires with a L{Failure} wrapping
        that exception.
        """

        class BrokenFile:
            def read(self, count):
                raise OSError("Simulated bad thing")

        producer = FileBodyProducer(BrokenFile(), self.cooperator)
        complete = producer.startProducing(FileConsumer(BytesIO()))
        self._scheduled.pop(0)()
        self.failureResultOf(complete).trap(OSError)

    def test_cancelWhileProducing(self):
        """
        When the L{Deferred} returned by L{FileBodyProducer.startProducing} is
        cancelled, the input file is closed and the task is stopped.
        """
        expectedResult = b"hello, world"
        readSize = 3
        output = BytesIO()
        consumer = FileConsumer(output)
        inputFile = BytesIO(expectedResult)
        producer = FileBodyProducer(inputFile, self.cooperator, readSize)
        complete = producer.startProducing(consumer)
        complete.cancel()
        self.assertTrue(inputFile.closed)
        self._scheduled.pop(0)()
        self.assertEqual(b"", output.getvalue())
        self.assertNoResult(complete)

    def test_stopProducing(self):
        """
        L{FileBodyProducer.stopProducing} stops the underlying L{IPullProducer}
        and the cooperative task responsible for calling C{resumeProducing} and
        closes the input file but does not cause the L{Deferred} returned by
        C{startProducing} to fire.
        """
        expectedResult = b"hello, world"
        readSize = 3
        output = BytesIO()
        consumer = FileConsumer(output)
        inputFile = BytesIO(expectedResult)
        producer = FileBodyProducer(inputFile, self.cooperator, readSize)
        complete = producer.startProducing(consumer)
        producer.stopProducing()
        self.assertTrue(inputFile.closed)
        self._scheduled.pop(0)()
        self.assertEqual(b"", output.getvalue())
        self.assertNoResult(complete)

    def test_pauseProducing(self):
        """
        L{FileBodyProducer.pauseProducing} temporarily suspends writing bytes
        from the input file to the given L{IConsumer}.
        """
        expectedResult = b"hello, world"
        readSize = 5
        output = BytesIO()
        consumer = FileConsumer(output)
        producer = FileBodyProducer(BytesIO(expectedResult), self.cooperator, readSize)
        complete = producer.startProducing(consumer)
        self._scheduled.pop(0)()
        self.assertEqual(output.getvalue(), expectedResult[:5])
        producer.pauseProducing()

        # Sort of depends on an implementation detail of Cooperator: even
        # though the only task is paused, there's still a scheduled call.  If
        # this were to go away because Cooperator became smart enough to cancel
        # this call in this case, that would be fine.
        self._scheduled.pop(0)()

        # Since the producer is paused, no new data should be here.
        self.assertEqual(output.getvalue(), expectedResult[:5])
        self.assertEqual([], self._scheduled)
        self.assertNoResult(complete)

    def test_resumeProducing(self):
        """
        L{FileBodyProducer.resumeProducing} re-commences writing bytes from the
        input file to the given L{IConsumer} after it was previously paused
        with L{FileBodyProducer.pauseProducing}.
        """
        expectedResult = b"hello, world"
        readSize = 5
        output = BytesIO()
        consumer = FileConsumer(output)
        producer = FileBodyProducer(BytesIO(expectedResult), self.cooperator, readSize)
        producer.startProducing(consumer)
        self._scheduled.pop(0)()
        self.assertEqual(expectedResult[:readSize], output.getvalue())
        producer.pauseProducing()
        producer.resumeProducing()
        self._scheduled.pop(0)()
        self.assertEqual(expectedResult[: readSize * 2], output.getvalue())

    def test_multipleStop(self):
        """
        L{FileBodyProducer.stopProducing} can be called more than once without
        raising an exception.
        """
        expectedResult = b"test"
        readSize = 3
        output = BytesIO()
        consumer = FileConsumer(output)
        inputFile = BytesIO(expectedResult)
        producer = FileBodyProducer(inputFile, self.cooperator, readSize)
        complete = producer.startProducing(consumer)
        producer.stopProducing()
        producer.stopProducing()
        self.assertTrue(inputFile.closed)
        self._scheduled.pop(0)()
        self.assertEqual(b"", output.getvalue())
        self.assertNoResult(complete)


# Fixed addresses that FakeReactorAndConnectMixin.createReactor resolves the
# test hostnames to; the trailing comment names the hostname for each.
EXAMPLE_COM_IP = "127.0.0.7"  # example.com
EXAMPLE_COM_V6_IP = "::7"  # ipv6.example.com
EXAMPLE_NET_IP = "127.0.0.8"  # example.net
EXAMPLE_ORG_IP = "127.0.0.9"  # example.org
FOO_LOCAL_IP = "127.0.0.10"  # foo
FOO_COM_IP = "127.0.0.11"  # foo.com


class FakeReactorAndConnectMixin:
    """
    Mixin that gives test cases a fake, deterministic reactor plus a
    C{connect} method, so the test case itself can be handed out wherever an
    endpoint is expected.
    """

    def createReactor(self):
        """
        Build a reactor for tests: a L{MemoryReactorClock} wrapped so that a
        fixed table of hostnames resolves deterministically.

        @return: An object exposing the wrapped reactor's formally-declared
            interfaces, plus the C{tcpClients} list and C{advance} method
            borrowed from the underlying L{MemoryReactorClock}.
        """
        memoryReactor = MemoryReactorClock()
        knownHosts = {
            "example.com": [EXAMPLE_COM_IP],
            "ipv6.example.com": [EXAMPLE_COM_V6_IP],
            "example.net": [EXAMPLE_NET_IP],
            "example.org": [EXAMPLE_ORG_IP],
            "foo": [FOO_LOCAL_IP],
            "foo.com": [FOO_COM_IP],
            # Literal addresses resolve to themselves.
            "127.0.0.7": ["127.0.0.7"],
            "::7": ["::7"],
        }
        resolving = deterministicResolvingReactor(memoryReactor, hostMap=knownHosts)

        # Many tests assume the reactor seen by the code under test *is* the
        # MemoryReactorClock, so expose the pieces of it they inspect.
        resolving.tcpClients = memoryReactor.tcpClients
        resolving.advance = memoryReactor.advance
        return resolving

    class StubEndpoint:
        """
        Wraps a real endpoint but always hands back one L{StubHTTPProtocol},
        which is also attached to the owning test case on connect.

        @ivar endpoint: The wrapped endpoint.
        @ivar testCase: The test case that receives C{protocol} as its
            C{protocol} attribute.
        @ivar factory: An L{_HTTP11ClientFactory} whose C{buildProtocol} always
            returns C{protocol}.
        @ivar protocol: The L{StubHTTPProtocol} this endpoint produces.
        """

        def __init__(self, endpoint, testCase):
            self.endpoint = endpoint
            self.testCase = testCase

            def ignoreQuiescence():
                """A callback that deliberately does nothing."""

            self.factory = _HTTP11ClientFactory(ignoreQuiescence, repr(self.endpoint))
            self.protocol = StubHTTPProtocol()

            def buildProtocol(addr):
                # Read lazily so the current self.protocol is always returned.
                return self.protocol

            self.factory.buildProtocol = buildProtocol

        def connect(self, ignoredFactory):
            """
            Publish the stub protocol on the test case, drive the wrapped
            endpoint with the stub factory, and succeed with the stub protocol
            (the wrapped endpoint's own result is not waited on).
            """
            self.testCase.protocol = self.protocol
            self.endpoint.connect(self.factory)
            return succeed(self.protocol)

    def buildAgentForWrapperTest(self, reactor):
        """
        Return an Agent suitable for use in tests that wrap the Agent and want
        both a fake reactor and StubHTTPProtocol.

        Every endpoint the agent would normally use is wrapped in a
        L{StubEndpoint} bound to this test case.
        """
        agent = client.Agent(reactor)
        realGetEndpoint = agent._getEndpoint

        def stubbedGetEndpoint(*args):
            return self.StubEndpoint(realGetEndpoint(*args), self)

        agent._getEndpoint = stubbedGetEndpoint
        return agent

    def connect(self, factory):
        """
        Pretend to be an endpoint: ignore C{factory} and synchronously succeed
        with a new L{StubHTTPProtocol} (given no transport), which is also
        stored as C{self.protocol}.
        """
        stub = StubHTTPProtocol()
        stub.makeConnection(None)
        self.protocol = stub
        return succeed(stub)


class DummyEndpoint:
    """
    An endpoint whose connections are backed by an in-memory
    L{StringTransport} instead of a real socket.
    """

    def connect(self, factory):
        """
        Build a protocol from C{factory} (with no address), attach it to a
        fresh L{StringTransport}, and return an already-fired L{Deferred}
        carrying that protocol.
        """
        built = factory.buildProtocol(None)
        transport = StringTransport()
        built.makeConnection(transport)
        return succeed(built)


class BadEndpoint:
    """
    An endpoint that fails loudly if anything tries to connect through it.
    Tests hand it out when a cached connection is expected to be reused.
    """

    def connect(self, factory):
        """
        @raise RuntimeError: Always; no connection should ever be attempted.
        """
        message = "This endpoint should not have been used."
        raise RuntimeError(message)


class DummyFactory(Factory):
    """
    Drop-in replacement for a connection pool's protocol factory that builds
    C{StubHTTPProtocol} instances and ignores its constructor arguments.
    """

    protocol = StubHTTPProtocol

    def __init__(self, quiescentCallback, metadata):
        # Both arguments exist only for signature compatibility with the real
        # factory and are deliberately discarded.
        del quiescentCallback, metadata


class HTTPConnectionPoolTests(TestCase, FakeReactorAndConnectMixin):
    """
    Tests for the L{HTTPConnectionPool} class.
    """

    def setUp(self):
        """
        Create a non-retrying L{HTTPConnectionPool} on a fake reactor, whose
        protocol factory builds L{StubHTTPProtocol} instances.
        """
        self.fakeReactor = self.createReactor()
        self.pool = HTTPConnectionPool(self.fakeReactor)
        self.pool._factory = DummyFactory
        # The retry code path is tested in HTTPConnectionPoolRetryTests:
        self.pool.retryAutomatically = False

    def test_getReturnsNewIfCacheEmpty(self):
        """
        If there are no cached connections,
        L{HTTPConnectionPool.getConnection} returns a new connection.
        """
        self.assertEqual(self.pool._connections, {})

        def gotConnection(conn):
            self.assertIsInstance(conn, StubHTTPProtocol)
            # The new connection is not stored in the pool:
            self.assertNotIn(conn, self.pool._connections.values())

        unknownKey = 12245
        d = self.pool.getConnection(unknownKey, DummyEndpoint())
        return d.addCallback(gotConnection)

    def test_putStartsTimeout(self):
        """
        If a connection is put back to the pool, a 240-sec timeout is started.

        When the timeout hits, the connection is closed and removed from the
        pool.
        """
        # We start out with one cached connection:
        protocol = StubHTTPProtocol()
        protocol.makeConnection(StringTransport())
        self.pool._putConnection(("http", b"example.com", 80), protocol)

        # Connection is in pool, still not closed:
        self.assertEqual(protocol.transport.disconnecting, False)
        self.assertIn(protocol, self.pool._connections[("http", b"example.com", 80)])

        # Advance 239 seconds, still not closed:
        self.fakeReactor.advance(239)
        self.assertEqual(protocol.transport.disconnecting, False)
        self.assertIn(protocol, self.pool._connections[("http", b"example.com", 80)])
        self.assertIn(protocol, self.pool._timeouts)

        # Advance past 240 seconds, connection will be closed:
        self.fakeReactor.advance(1.1)
        self.assertEqual(protocol.transport.disconnecting, True)
        self.assertNotIn(protocol, self.pool._connections[("http", b"example.com", 80)])
        self.assertNotIn(protocol, self.pool._timeouts)

    def test_putExceedsMaxPersistent(self):
        """
        If an idle connection is put back in the cache and the max number of
        persistent connections has been exceeded, one of the connections is
        closed and removed from the cache.
        """
        pool = self.pool

        # We start out with two cached connections, the max:
        origCached = [StubHTTPProtocol(), StubHTTPProtocol()]
        for p in origCached:
            p.makeConnection(StringTransport())
            pool._putConnection(("http", b"example.com", 80), p)
        self.assertEqual(pool._connections[("http", b"example.com", 80)], origCached)
        timeouts = pool._timeouts.copy()

        # Now we add another one:
        newProtocol = StubHTTPProtocol()
        newProtocol.makeConnection(StringTransport())
        pool._putConnection(("http", b"example.com", 80), newProtocol)

        # The oldest cached connections will be removed and disconnected:
        newCached = pool._connections[("http", b"example.com", 80)]
        self.assertEqual(len(newCached), 2)
        self.assertEqual(newCached, [origCached[1], newProtocol])
        self.assertEqual([p.transport.disconnecting for p in newCached], [False, False])
        self.assertEqual(origCached[0].transport.disconnecting, True)
        self.assertTrue(timeouts[origCached[0]].cancelled)
        self.assertNotIn(origCached[0], pool._timeouts)

    def test_maxPersistentPerHost(self):
        """
        C{maxPersistentPerHost} is enforced per C{(scheme, host, port)}:
        different keys have different max connections.
        """

        def addProtocol(scheme, host, port):
            p = StubHTTPProtocol()
            p.makeConnection(StringTransport())
            self.pool._putConnection((scheme, host, port), p)
            return p

        persistent = []
        persistent.append(addProtocol("http", b"example.com", 80))
        persistent.append(addProtocol("http", b"example.com", 80))
        addProtocol("https", b"example.com", 443)
        addProtocol("http", b"www2.example.com", 80)

        self.assertEqual(
            self.pool._connections[("http", b"example.com", 80)], persistent
        )
        self.assertEqual(len(self.pool._connections[("https", b"example.com", 443)]), 1)
        self.assertEqual(
            len(self.pool._connections[("http", b"www2.example.com", 80)]), 1
        )

    def test_getCachedConnection(self):
        """
        Getting an address which has a cached connection returns the cached
        connection, removes it from the cache and cancels its timeout.
        """
        # We start out with one cached connection:
        protocol = StubHTTPProtocol()
        protocol.makeConnection(StringTransport())
        self.pool._putConnection(("http", b"example.com", 80), protocol)

        def gotConnection(conn):
            # We got the cached connection:
            self.assertIdentical(protocol, conn)
            self.assertNotIn(conn, self.pool._connections[("http", b"example.com", 80)])
            # And the timeout was cancelled:
            self.fakeReactor.advance(241)
            self.assertEqual(conn.transport.disconnecting, False)
            self.assertNotIn(conn, self.pool._timeouts)

        return self.pool.getConnection(
            ("http", b"example.com", 80),
            BadEndpoint(),
        ).addCallback(gotConnection)

    def test_newConnection(self):
        """
        The pool's C{_newConnection} method constructs a new connection.
        """
        # We start out with one cached connection:
        protocol = StubHTTPProtocol()
        protocol.makeConnection(StringTransport())
        key = 12245
        self.pool._putConnection(key, protocol)

        def gotConnection(newConnection):
            # We got a new connection:
            self.assertNotIdentical(protocol, newConnection)
            # And the old connection is still there:
            self.assertIn(protocol, self.pool._connections[key])
            # While the new connection is not:
            self.assertNotIn(newConnection, self.pool._connections.values())

        d = self.pool._newConnection(key, DummyEndpoint())
        return d.addCallback(gotConnection)

    def test_getSkipsDisconnected(self):
        """
        When getting connections out of the cache, disconnected connections
        are removed and not returned.
        """
        pool = self.pool
        key = ("http", b"example.com", 80)

        # We start out with two cached connections, the max:
        origCached = [StubHTTPProtocol(), StubHTTPProtocol()]
        for p in origCached:
            p.makeConnection(StringTransport())
            pool._putConnection(key, p)
        self.assertEqual(pool._connections[key], origCached)

        # We close the first one:
        origCached[0].state = "DISCONNECTED"

        # Now, when we retrieve connections we should get the *second* one:
        result = []
        self.pool.getConnection(key, BadEndpoint()).addCallback(result.append)
        self.assertIdentical(result[0], origCached[1])

        # And both the disconnected and removed connections should be out of
        # the cache:
        self.assertEqual(pool._connections[key], [])
        self.assertEqual(pool._timeouts, {})

    def test_putNotQuiescent(self):
        """
        If a non-quiescent connection is put back in the cache, an error is
        logged.
        """
        protocol = StubHTTPProtocol()
        # By default state is QUIESCENT
        self.assertEqual(protocol.state, "QUIESCENT")

        logObserver = EventLoggingObserver.createWithCleanup(self, globalLogPublisher)

        protocol.state = "NOTQUIESCENT"
        self.pool._putConnection(("http", b"example.com", 80), protocol)
        self.assertEqual(1, len(logObserver))

        event = logObserver[0]
        f = event["log_failure"]

        self.assertIsInstance(f.value, RuntimeError)
        self.assertEqual(
            f.getErrorMessage(), "BUG: Non-quiescent protocol added to connection pool."
        )
        self.assertIdentical(
            None, self.pool._connections.get(("http", b"example.com", 80))
        )
        self.flushLoggedErrors(RuntimeError)

    def test_getUsesQuiescentCallback(self):
        """
        When L{HTTPConnectionPool.getConnection} connects, it returns a
        C{Deferred} that fires with an instance of L{HTTP11ClientProtocol}
        that has the correct quiescent callback attached. When this callback
        is called the protocol is returned to the cache correctly, using the
        right key.
        """

        class StringEndpoint:
            def connect(self, factory):
                p = factory.buildProtocol(None)
                p.makeConnection(StringTransport())
                return succeed(p)

        pool = HTTPConnectionPool(self.fakeReactor, True)
        pool.retryAutomatically = False
        result = []
        key = "a key"
        pool.getConnection(key, StringEndpoint()).addCallback(result.append)
        protocol = result[0]
        self.assertIsInstance(protocol, HTTP11ClientProtocol)

        # Now that we have a protocol instance, let's try to put it back in
        # the pool:
        protocol._state = "QUIESCENT"
        protocol._quiescentCallback(protocol)

        # If we try to retrieve a connection to the same destination again, we
        # should get the same protocol, because it should've been added back
        # to the pool:
        result2 = []
        pool.getConnection(key, StringEndpoint()).addCallback(result2.append)
        self.assertIdentical(result2[0], protocol)

    def test_closeCachedConnections(self):
        """
        L{HTTPConnectionPool.closeCachedConnections} closes all cached
        connections and removes them from the cache. It returns a Deferred
        that fires when they have all lost their connections.
        """
        persistent = []

        def addProtocol(scheme, host, port):
            p = HTTP11ClientProtocol()
            p.makeConnection(StringTransport())
            self.pool._putConnection((scheme, host, port), p)
            persistent.append(p)

        addProtocol("http", b"example.com", 80)
        addProtocol("http", b"www2.example.com", 80)
        doneDeferred = self.pool.closeCachedConnections()

        # Connections have begun disconnecting:
        for p in persistent:
            self.assertEqual(p.transport.disconnecting, True)
        self.assertEqual(self.pool._connections, {})
        # All timeouts were cancelled and removed:
        for dc in self.fakeReactor.getDelayedCalls():
            self.assertEqual(dc.cancelled, True)
        self.assertEqual(self.pool._timeouts, {})

        # Returned Deferred fires when all connections have been closed:
        result = []
        doneDeferred.addCallback(result.append)
        self.assertEqual(result, [])
        persistent[0].connectionLost(Failure(ConnectionDone()))
        self.assertEqual(result, [])
        persistent[1].connectionLost(Failure(ConnectionDone()))
        self.assertEqual(result, [None])

    def test_cancelGetConnectionCancelsEndpointConnect(self):
        """
        Cancelling the C{Deferred} returned from
        L{HTTPConnectionPool.getConnection} cancels the C{Deferred} returned
        by opening a new connection with the given endpoint.
        """
        self.assertEqual(self.pool._connections, {})
        connectionResult = Deferred()

        class Endpoint:
            def connect(self, factory):
                return connectionResult

        d = self.pool.getConnection(12345, Endpoint())
        d.cancel()
        self.assertEqual(self.failureResultOf(connectionResult).type, CancelledError)


class AgentTestsMixin:
    """
    Tests applicable to any L{IAgent} implementation.  Classes mixing this in
    must supply a C{makeAgent} method that returns the agent under test.
    """

    def test_interface(self):
        """
        Whatever C{makeAgent} builds provides L{IAgent}.
        """
        agent = self.makeAgent()
        conforms = verifyObject(IAgent, agent)
        self.assertTrue(conforms)


class IntegrationTestingMixin:
    """
    Transport-to-Agent integration tests for both HTTP and HTTPS.

    Classes mixing this in must provide C{createReactor} (whose reactor
    records outgoing connections in C{tcpClients}) and the usual
    L{TestCase} assertion helpers.
    """

    def test_integrationTestIPv4(self):
        """
        L{Agent} works over IPv4.
        """
        self.integrationTest(b"example.com", EXAMPLE_COM_IP, IPv4Address)

    def test_integrationTestIPv4Address(self):
        """
        L{Agent} works over IPv4 when hostname is an IPv4 address.
        """
        self.integrationTest(b"127.0.0.7", "127.0.0.7", IPv4Address)

    def test_integrationTestIPv6(self):
        """
        L{Agent} works over IPv6.
        """
        self.integrationTest(b"ipv6.example.com", EXAMPLE_COM_V6_IP, IPv6Address)

    def test_integrationTestIPv6Address(self):
        """
        L{Agent} works over IPv6 when hostname is an IPv6 address.
        """
        self.integrationTest(b"[::7]", "::7", IPv6Address)

    def integrationTest(
        self,
        hostName,
        expectedAddress,
        addressType,
        serverWrapper=lambda server: server,
        createAgent=client.Agent,
        scheme=b"http",
    ):
        """
        L{Agent} will make a TCP connection, send an HTTP request, and return a
        L{Deferred} that fires when the response has been received.

        @param hostName: The hostname to interpolate into the URL to be
            requested.
        @type hostName: L{bytes}

        @param expectedAddress: The expected address string, compared with the
            host the reactor was asked to connect to.
        @type expectedAddress: L{str}

        @param addressType: The class to construct an address out of.
        @type addressType: L{type}

        @param serverWrapper: A callable that takes a protocol factory and
            returns a protocol factory; used to wrap the server / responder
            side in a TLS server.
        @type serverWrapper:
            serverWrapper(L{twisted.internet.interfaces.IProtocolFactory}) ->
            L{twisted.internet.interfaces.IProtocolFactory}

        @param createAgent: A callable that takes a reactor and produces an
            L{IAgent}; used to construct an agent with an appropriate trust
            root for TLS.
        @type createAgent: createAgent(reactor) -> L{IAgent}

        @param scheme: The scheme to test, C{http} or C{https}
        @type scheme: L{bytes}
        """
        reactor = self.createReactor()
        agent = createAgent(reactor)
        deferred = agent.request(b"GET", scheme + b"://" + hostName + b"/")
        host, port, factory, timeout, bind = reactor.tcpClients[0]
        self.assertEqual(host, expectedAddress)
        peerAddress = addressType("TCP", host, port)
        clientProtocol = factory.buildProtocol(peerAddress)
        clientTransport = FakeTransport(clientProtocol, False, peerAddress=peerAddress)
        clientProtocol.makeConnection(clientTransport)

        @Factory.forProtocol
        def accumulator():
            ap = AccumulatingProtocol()
            accumulator.currentProtocol = ap
            return ap

        accumulator.currentProtocol = None
        accumulator.protocolConnectionMade = None
        wrapper = serverWrapper(accumulator).buildProtocol(None)
        serverTransport = FakeTransport(wrapper, True)
        wrapper.makeConnection(serverTransport)
        # Shuttle bytes between the client and server transports in memory.
        pump = IOPump(clientProtocol, wrapper, clientTransport, serverTransport, False)
        pump.flush()
        self.assertNoResult(deferred)
        lines = accumulator.currentProtocol.data.split(b"\r\n")
        self.assertTrue(lines[0].startswith(b"GET / HTTP"), lines[0])
        headers = dict(line.split(b": ", 1) for line in lines[1:] if line)
        self.assertEqual(headers[b"Host"], hostName)
        self.assertNoResult(deferred)
        # NOTE(review): the empty line after X-An-Header ends the header
        # block, so "Content-length: 12\r\n\r\nhello world!" arrives as body
        # bytes rather than as a header.  Only X-An-Header is checked below;
        # confirm before "fixing" the bytes, since the HTTPS variants share
        # this code path.
        accumulator.currentProtocol.transport.write(
            b"HTTP/1.1 200 OK"
            b"\r\nX-An-Header: an-value\r\n"
            b"\r\nContent-length: 12\r\n\r\n"
            b"hello world!"
        )
        pump.flush()
        response = self.successResultOf(deferred)
        self.assertEqual(
            response.headers.getRawHeaders(b"x-an-header")[0], b"an-value"
        )


@implementer(IAgentEndpointFactory)
class StubEndpointFactory:
    """
    Fake L{IAgentEndpointFactory} for tests: instead of building a real
    endpoint it hands back the identifying parts of the URI it was given.
    """

    def endpointForURI(self, uri):
        """
        Report which target an endpoint would have been created for.

        @param uri: A L{URI}.

        @return: The C{(scheme, host, port)} triple taken from C{uri}.  This
            deliberately violates the interface so tests can compare the
            result directly.
        @rtype: L{tuple}
        """
        scheme, host, port = uri.scheme, uri.host, uri.port
        return scheme, host, port


class AgentTests(
    TestCase, FakeReactorAndConnectMixin, AgentTestsMixin, IntegrationTestingMixin
):
    """
    Tests for the new HTTP client API provided by L{Agent}.
    """

    def makeAgent(self):
        """
        @return: a new L{twisted.web.client.Agent} instance
        """
        return client.Agent(self.reactor)

    def setUp(self):
        """
        Create an L{Agent} wrapped around a fake reactor.
        """
        self.reactor = self.createReactor()
        self.agent = self.makeAgent()

    def test_defaultPool(self):
        """
        If no pool is passed in, the L{Agent} creates a non-persistent pool.
        """
        agent = client.Agent(self.reactor)
        self.assertIsInstance(agent._pool, HTTPConnectionPool)
        self.assertEqual(agent._pool.persistent, False)
        self.assertIdentical(agent._reactor, agent._pool._reactor)

    def test_persistent(self):
        """
        If C{persistent} is set to C{True} on the L{HTTPConnectionPool} (the
        default), C{Request}s are created with their C{persistent} flag set to
        C{True}.
        """
        pool = HTTPConnectionPool(self.reactor)
        agent = client.Agent(self.reactor, pool=pool)
        agent._getEndpoint = lambda *args: self
        agent.request(b"GET", b"http://127.0.0.1")
        self.assertEqual(self.protocol.requests[0][0].persistent, True)

    def test_nonPersistent(self):
        """
        If C{persistent} is set to C{False} when creating the
        L{HTTPConnectionPool}, C{Request}s are created with their
        C{persistent} flag set to C{False}.

        Elsewhere in the tests for the underlying HTTP code we ensure that
        this will result in the disconnection of the HTTP protocol once the
        request is done, so that the connection will not be returned to the
        pool.
        """
        pool = HTTPConnectionPool(self.reactor, persistent=False)
        agent = client.Agent(self.reactor, pool=pool)
        agent._getEndpoint = lambda *args: self
        agent.request(b"GET", b"http://127.0.0.1")
        self.assertEqual(self.protocol.requests[0][0].persistent, False)

    def test_connectUsesConnectionPool(self):
        """
        When a connection is made by the Agent, it uses its pool's
        C{getConnection} method to do so, with the endpoint returned by
        C{self._getEndpoint}. The key used is C{(scheme, host, port)}.
        """
        endpoint = DummyEndpoint()

        class MyAgent(client.Agent):
            def _getEndpoint(this, uri):
                self.assertEqual(
                    (uri.scheme, uri.host, uri.port), (b"http", b"foo", 80)
                )
                return endpoint

        class DummyPool:
            connected = False
            persistent = False

            def getConnection(this, key, ep):
                this.connected = True
                self.assertEqual(ep, endpoint)
                # This is the key the default Agent uses, others will have
                # different keys:
                self.assertEqual(key, (b"http", b"foo", 80))
                return defer.succeed(StubHTTPProtocol())

        pool = DummyPool()
        agent = MyAgent(self.reactor, pool=pool)
        self.assertIdentical(pool, agent._pool)

        headers = http_headers.Headers()
        headers.addRawHeader(b"host", b"foo")
        bodyProducer = object()
        agent.request(
            b"GET", b"http://foo/", bodyProducer=bodyProducer, headers=headers
        )
        self.assertEqual(agent._pool.connected, True)

    def test_nonBytesMethod(self):
        """
        L{Agent.request} raises L{TypeError} when the C{method} argument isn't
        L{bytes}.
        """
        self.assertRaises(TypeError, self.agent.request, "GET", b"http://foo.example/")

    def test_unsupportedScheme(self):
        """
        L{Agent.request} returns a L{Deferred} which fails with
        L{SchemeNotSupported} if the scheme of the URI passed to it is
        neither C{'http'} nor C{'https'}.
        """
        return self.assertFailure(
            self.agent.request(b"GET", b"mailto:alice@example.com"), SchemeNotSupported
        )

    def test_connectionFailed(self):
        """
        The L{Deferred} returned by L{Agent.request} fires with a L{Failure} if
        the TCP connection attempt fails.
        """
        result = self.agent.request(b"GET", b"http://foo/")
        # Cause the connection to be refused.  Only the client factory (the
        # third element of the recorded connectTCP call) is needed for that.
        factory = self.reactor.tcpClients.pop()[2]
        factory.clientConnectionFailed(None, Failure(ConnectionRefusedError()))
        self.reactor.advance(10)
        # ^ https://twistedmatrix.com/trac/ticket/8202
        self.failureResultOf(result, ConnectionRefusedError)

    def test_connectHTTP(self):
        """
        L{Agent._getEndpoint} returns a C{HostnameEndpoint} when passed a
        scheme of C{'http'}.
        """
        expectedHost = b"example.com"
        expectedPort = 1234
        endpoint = self.agent._getEndpoint(
            URI.fromBytes(b"http://%b:%d" % (expectedHost, expectedPort))
        )
        # The endpoint stores the host as text, not bytes.
        self.assertEqual(endpoint._hostStr, expectedHost.decode("ascii"))
        self.assertEqual(endpoint._port, expectedPort)
        self.assertIsInstance(endpoint, HostnameEndpoint)

    def test_nonDecodableURI(self):
        """
        L{Agent._getEndpoint}, when given a URI whose host is not
        ASCII-decodable, raises a L{ValueError} saying so.
        """
        uri = URI.fromBytes(b"http://example.com:80")
        uri.host = "\u2603.com".encode()

        with self.assertRaises(ValueError) as e:
            self.agent._getEndpoint(uri)

        self.assertEqual(
            e.exception.args[0],
            (
                "The host of the provided URI ({reprout}) contains "
                "non-ASCII octets, it should be ASCII "
                "decodable."
            ).format(reprout=repr(uri.host)),
        )

    def test_hostProvided(self):
        """
        If L{None} is passed to L{Agent.request} for the C{headers} parameter,
        a L{Headers} instance is created for the request and a I{Host} header
        added to it.
        """
        self.agent._getEndpoint = lambda *args: self
        self.agent.request(b"GET", b"http://example.com/foo?bar")

        req, res = self.protocol.requests.pop()
        self.assertEqual(req.headers.getRawHeaders(b"host"), [b"example.com"])

    def test_hostIPv6Bracketed(self):
        """
        If an IPv6 address is used in the C{uri} passed to L{Agent.request},
        the computed I{Host} header needs to be bracketed.
        """
        self.agent._getEndpoint = lambda *args: self
        self.agent.request(b"GET", b"http://[::1]/")

        req, res = self.protocol.requests.pop()
        self.assertEqual(req.headers.getRawHeaders(b"host"), [b"[::1]"])

    def test_hostOverride(self):
        """
        If the headers passed to L{Agent.request} includes a value for the
        I{Host} header, that value takes precedence over the one which would
        otherwise be automatically provided.
        """
        headers = http_headers.Headers({b"foo": [b"bar"], b"host": [b"quux"]})
        self.agent._getEndpoint = lambda *args: self
        self.agent.request(b"GET", b"http://example.com/foo?bar", headers)

        req, res = self.protocol.requests.pop()
        self.assertEqual(req.headers.getRawHeaders(b"host"), [b"quux"])

    def test_headersUnmodified(self):
        """
        If a I{Host} header must be added to the request, the L{Headers}
        instance passed to L{Agent.request} is not modified.
        """
        headers = http_headers.Headers()
        self.agent._getEndpoint = lambda *args: self
        self.agent.request(b"GET", b"http://example.com/foo", headers)

        protocol = self.protocol

        # The request should have been issued.
        self.assertEqual(len(protocol.requests), 1)
        # And the headers object passed in should not have changed.
        self.assertEqual(headers, http_headers.Headers())

    def test_hostValueStandardHTTP(self):
        """
        When passed a scheme of C{'http'} and a port of C{80},
        L{Agent._computeHostValue} returns a string giving just
        the host name passed to it.
        """
        self.assertEqual(
            self.agent._computeHostValue(b"http", b"example.com", 80), b"example.com"
        )

    def test_hostValueNonStandardHTTP(self):
        """
        When passed a scheme of C{'http'} and a port other than C{80},
        L{Agent._computeHostValue} returns a string giving the
        host passed to it joined together with the port number by C{":"}.
        """
        self.assertEqual(
            self.agent._computeHostValue(b"http", b"example.com", 54321),
            b"example.com:54321",
        )

    def test_hostValueStandardHTTPS(self):
        """
        When passed a scheme of C{'https'} and a port of C{443},
        L{Agent._computeHostValue} returns a string giving just
        the host name passed to it.
        """
        self.assertEqual(
            self.agent._computeHostValue(b"https", b"example.com", 443), b"example.com"
        )

    def test_hostValueNonStandardHTTPS(self):
        """
        When passed a scheme of C{'https'} and a port other than C{443},
        L{Agent._computeHostValue} returns a string giving the
        host passed to it joined together with the port number by C{":"}.
        """
        self.assertEqual(
            self.agent._computeHostValue(b"https", b"example.com", 54321),
            b"example.com:54321",
        )

    def test_request(self):
        """
        L{Agent.request} establishes a new connection to the host indicated by
        the host part of the URI passed to it and issues a request using the
        method, the path portion of the URI, the headers, and the body producer
        passed to it.  It returns a L{Deferred} which fires with an
        L{IResponse} from the server.
        """
        self.agent._getEndpoint = lambda *args: self

        headers = http_headers.Headers({b"foo": [b"bar"]})
        # Just going to check the body for identity, so it doesn't need to be
        # real.
        body = object()
        self.agent.request(b"GET", b"http://example.com:1234/foo?bar", headers, body)

        protocol = self.protocol

        # The request should be issued.
        self.assertEqual(len(protocol.requests), 1)
        req, res = protocol.requests.pop()
        self.assertIsInstance(req, Request)
        self.assertEqual(req.method, b"GET")
        self.assertEqual(req.uri, b"/foo?bar")
        self.assertEqual(
            req.headers,
            http_headers.Headers({b"foo": [b"bar"], b"host": [b"example.com:1234"]}),
        )
        self.assertIdentical(req.bodyProducer, body)

    def test_connectTimeout(self):
        """
        L{Agent} takes a C{connectTimeout} argument which is forwarded to the
        following C{connectTCP} call.
        """
        agent = client.Agent(self.reactor, connectTimeout=5)
        agent.request(b"GET", b"http://foo/")
        # The fourth element of a recorded connectTCP call is its timeout.
        timeout = self.reactor.tcpClients.pop()[3]
        self.assertEqual(5, timeout)

    @skipIf(not sslPresent, "SSL not present, cannot run SSL tests.")
    def test_connectTimeoutHTTPS(self):
        """
        L{Agent} takes a C{connectTimeout} argument which is forwarded to the
        following C{connectTCP} call.
        """
        agent = client.Agent(self.reactor, connectTimeout=5)
        agent.request(b"GET", b"https://foo/")
        timeout = self.reactor.tcpClients.pop()[3]
        self.assertEqual(5, timeout)

    def test_bindAddress(self):
        """
        L{Agent} takes a C{bindAddress} argument which is forwarded to the
        following C{connectTCP} call.
        """
        agent = client.Agent(self.reactor, bindAddress="192.168.0.1")
        agent.request(b"GET", b"http://foo/")
        # The fifth element of a recorded connectTCP call is its bindAddress.
        address = self.reactor.tcpClients.pop()[4]
        self.assertEqual("192.168.0.1", address)

    @skipIf(not sslPresent, "SSL not present, cannot run SSL tests.")
    def test_bindAddressSSL(self):
        """
        L{Agent} takes a C{bindAddress} argument which is forwarded to the
        following C{connectTCP} call, even for HTTPS requests (TLS is
        layered over the TCP connection).
        """
        agent = client.Agent(self.reactor, bindAddress="192.168.0.1")
        agent.request(b"GET", b"https://foo/")
        address = self.reactor.tcpClients.pop()[4]
        self.assertEqual("192.168.0.1", address)

    def test_responseIncludesRequest(self):
        """
        L{Response}s returned by L{Agent.request} have a reference to the
        L{Request} that was originally issued.
        """
        uri = b"http://example.com/"
        agent = self.buildAgentForWrapperTest(self.reactor)
        d = agent.request(b"GET", uri)

        # The request should be issued.
        self.assertEqual(len(self.protocol.requests), 1)
        req, res = self.protocol.requests.pop()
        self.assertIsInstance(req, Request)

        resp = client.Response._construct(
            (b"HTTP", 1, 1), 200, b"OK", client.Headers({}), None, req
        )
        res.callback(resp)

        response = self.successResultOf(d)
        self.assertEqual(
            (
                response.request.method,
                response.request.absoluteURI,
                response.request.headers,
            ),
            (req.method, req.absoluteURI, req.headers),
        )

    def test_requestAbsoluteURI(self):
        """
        L{Request.absoluteURI} is the absolute URI of the request.
        """
        uri = b"http://example.com/foo;1234?bar#frag"
        agent = self.buildAgentForWrapperTest(self.reactor)
        agent.request(b"GET", uri)

        # The request should be issued.
        self.assertEqual(len(self.protocol.requests), 1)
        req, res = self.protocol.requests.pop()
        self.assertIsInstance(req, Request)
        self.assertEqual(req.absoluteURI, uri)

    def test_requestMissingAbsoluteURI(self):
        """
        L{Request.absoluteURI} is L{None} if L{Request._parsedURI} is L{None}.
        """
        request = client.Request(b"FOO", b"/", client.Headers(), None)
        self.assertIdentical(request.absoluteURI, None)

    def test_endpointFactory(self):
        """
        L{Agent.usingEndpointFactory} creates an L{Agent} that uses the given
        factory to create endpoints.
        """
        factory = StubEndpointFactory()
        agent = client.Agent.usingEndpointFactory(None, endpointFactory=factory)
        uri = URI.fromBytes(b"http://example.com/")
        returnedEndpoint = agent._getEndpoint(uri)
        self.assertEqual(returnedEndpoint, (b"http", b"example.com", 80))

    def test_endpointFactoryDefaultPool(self):
        """
        If no pool is passed in to L{Agent.usingEndpointFactory}, a default
        pool is constructed with no persistent connections.
        """
        agent = client.Agent.usingEndpointFactory(self.reactor, StubEndpointFactory())
        pool = agent._pool
        self.assertEqual(
            (pool.__class__, pool.persistent, pool._reactor),
            (HTTPConnectionPool, False, agent._reactor),
        )

    def test_endpointFactoryPool(self):
        """
        If a pool is passed in to L{Agent.usingEndpointFactory} it is used as
        the L{Agent} pool.
        """
        pool = object()
        agent = client.Agent.usingEndpointFactory(
            self.reactor, StubEndpointFactory(), pool
        )
        self.assertIs(pool, agent._pool)


class AgentMethodInjectionTests(
    FakeReactorAndConnectMixin,
    MethodInjectionTestsMixin,
    SynchronousTestCase,
):
    """
    Run the L{MethodInjectionTestsMixin} checks against L{client.Agent}.
    """

    def attemptRequestWithMaliciousMethod(self, method):
        """
        Issue a request through a fresh L{client.Agent} using C{method}.

        @param method: see L{MethodInjectionTestsMixin}
        """
        reactor = self.createReactor()
        targetURI = b"http://twisted.invalid"
        client.Agent(reactor).request(method, targetURI, client.Headers(), None)


class AgentURIInjectionTests(
    FakeReactorAndConnectMixin,
    URIInjectionTestsMixin,
    SynchronousTestCase,
):
    """
    Test L{client.Agent} against URI injections.
    """

    def attemptRequestWithMaliciousURI(self, uri):
        """
        Attempt a C{GET} request with the provided URI.

        @param uri: see L{URIInjectionTestsMixin}
        """
        agent = client.Agent(self.createReactor())
        method = b"GET"
        agent.request(method, uri, client.Headers(), None)


@skipIf(not sslPresent, "SSL not present, cannot run SSL tests.")
class AgentHTTPSTests(TestCase, FakeReactorAndConnectMixin, IntegrationTestingMixin):
    """
    Tests for the new HTTP client API that depends on SSL.
    """

    def makeEndpoint(self, host=b"example.com", port=443):
        """
        Create an L{Agent} with an https scheme and return its endpoint
        created according to the arguments.

        @param host: The host for the endpoint.
        @type host: L{bytes}

        @param port: The port for the endpoint.
        @type port: L{int}

        @return: An endpoint of an L{Agent} constructed according to args: a
            C{_WrapperEndpoint} whose C{_wrappedEndpoint} is a
            L{HostnameEndpoint}.
        """
        return client.Agent(self.createReactor())._getEndpoint(
            URI.fromBytes(b"https://%b:%d/" % (host, port))
        )

    def test_endpointType(self):
        """
        L{Agent._getEndpoint} returns a C{_WrapperEndpoint} around a
        L{HostnameEndpoint} when passed a scheme of C{'https'}.
        """
        from twisted.internet.endpoints import _WrapperEndpoint

        endpoint = self.makeEndpoint()
        self.assertIsInstance(endpoint, _WrapperEndpoint)
        self.assertIsInstance(endpoint._wrappedEndpoint, HostnameEndpoint)

    def test_hostArgumentIsRespected(self):
        """
        If a host is passed, the endpoint respects it.
        """
        endpoint = self.makeEndpoint(host=b"example.com")
        self.assertEqual(endpoint._wrappedEndpoint._hostStr, "example.com")

    def test_portArgumentIsRespected(self):
        """
        If a port is passed, the endpoint respects it.
        """
        expectedPort = 4321
        endpoint = self.makeEndpoint(port=expectedPort)
        self.assertEqual(endpoint._wrappedEndpoint._port, expectedPort)

    def test_contextFactoryType(self):
        """
        L{Agent} wraps its connection creator creator (its HTTPS policy) and
        uses modern TLS APIs.
        """
        endpoint = self.makeEndpoint()
        contextFactory = endpoint._wrapperFactory(None)._connectionCreator
        self.assertIsInstance(contextFactory, ClientTLSOptions)
        self.assertEqual(contextFactory._hostname, "example.com")

    def test_connectHTTPSCustomConnectionCreator(self):
        """
        If a custom L{WebClientConnectionCreator}-like object is passed to
        L{Agent.__init__} it will be used to determine the SSL parameters for
        HTTPS requests.  When an HTTPS request is made, the hostname and port
        number of the request URL will be passed to the connection creator's
        C{creatorForNetloc} method.  The resulting context object will be used
        to establish the SSL connection.
        """
        expectedHost = b"example.org"
        expectedPort = 20443

        class JustEnoughConnection:
            handshakeStarted = False
            connectState = False

            def do_handshake(self):
                """
                The handshake started.  Record that fact.
                """
                self.handshakeStarted = True

            def set_connect_state(self):
                """
                The connection started.  Record that fact.
                """
                self.connectState = True

        contextArgs = []

        @implementer(IOpenSSLClientConnectionCreator)
        class JustEnoughCreator:
            def __init__(self, hostname, port):
                self.hostname = hostname
                self.port = port

            def clientConnectionForTLS(self, tlsProtocol):
                """
                Implement L{IOpenSSLClientConnectionCreator}.

                @param tlsProtocol: The TLS protocol.
                @type tlsProtocol: L{TLSMemoryBIOProtocol}

                @return: C{expectedConnection}
                """
                contextArgs.append((tlsProtocol, self.hostname, self.port))
                return expectedConnection

        expectedConnection = JustEnoughConnection()

        @implementer(IPolicyForHTTPS)
        class StubBrowserLikePolicyForHTTPS:
            def creatorForNetloc(self, hostname, port):
                """
                Emulate L{BrowserLikePolicyForHTTPS}.

                @param hostname: The hostname to verify.
                @type hostname: L{bytes}

                @param port: The port number.
                @type port: L{int}

                @return: a stub L{IOpenSSLClientConnectionCreator}
                @rtype: L{JustEnoughCreator}
                """
                return JustEnoughCreator(hostname, port)

        expectedCreatorCreator = StubBrowserLikePolicyForHTTPS()
        reactor = self.createReactor()
        agent = client.Agent(reactor, expectedCreatorCreator)
        endpoint = agent._getEndpoint(
            URI.fromBytes(b"https://%b:%d" % (expectedHost, expectedPort))
        )
        endpoint.connect(Factory.forProtocol(Protocol))
        tlsFactory = reactor.tcpClients[-1][2]
        tlsProtocol = tlsFactory.buildProtocol(None)
        tlsProtocol.makeConnection(StringTransport())
        tls = contextArgs[0][0]
        self.assertIsInstance(tls, TLSMemoryBIOProtocol)
        self.assertEqual(contextArgs[0][1:], (expectedHost, expectedPort))
        self.assertTrue(expectedConnection.handshakeStarted)
        self.assertTrue(expectedConnection.connectState)

    def test_deprecatedDuckPolicy(self):
        """
        Passing something that duck-types I{like} a L{web client context
        factory <twisted.web.client.WebClientContextFactory>} - something that
        does not provide L{IPolicyForHTTPS} - to L{Agent} emits a
        L{DeprecationWarning} even if you don't actually C{import
        WebClientContextFactory} to do it.
        """

        def warnMe():
            client.Agent(
                deterministicResolvingReactor(MemoryReactorClock()),
                "does-not-provide-IPolicyForHTTPS",
            )

        warnMe()
        warnings = self.flushWarnings([warnMe])
        self.assertEqual(len(warnings), 1)
        [warning] = warnings
        self.assertEqual(warning["category"], DeprecationWarning)
        self.assertEqual(
            warning["message"],
            "'does-not-provide-IPolicyForHTTPS' was passed as the HTTPS "
            "policy for an Agent, but it does not provide IPolicyForHTTPS.  "
            "Since Twisted 14.0, you must pass a provider of IPolicyForHTTPS.",
        )

    def test_alternateTrustRoot(self):
        """
        L{BrowserLikePolicyForHTTPS.creatorForNetloc} returns an
        L{IOpenSSLClientConnectionCreator} provider which will add certificates
        from the given trust root.
        """
        trustRoot = CustomOpenSSLTrustRoot()
        policy = BrowserLikePolicyForHTTPS(trustRoot=trustRoot)
        creator = policy.creatorForNetloc(b"thingy", 4321)
        self.assertTrue(trustRoot.called)
        connection = creator.clientConnectionForTLS(None)
        self.assertIs(trustRoot.context, connection.get_context())

    def integrationTest(self, hostName, expectedAddress, addressType):
        """
        Wrap L{IntegrationTestingMixin.integrationTest} with TLS.
        """
        # Bracketed IPv6 literals (e.g. b"[::1]") are valid in a URI, but the
        # certificate is issued for the bare host name.
        certHostName = hostName.strip(b"[]")
        authority, server = certificatesForAuthorityAndServer(
            certHostName.decode("ascii")
        )

        def tlsify(serverFactory):
            return TLSMemoryBIOFactory(server.options(), False, serverFactory)

        def tlsagent(reactor):
            # Trust only the test authority that issued the server certificate.
            @implementer(IPolicyForHTTPS)
            class Policy:
                def creatorForNetloc(self, hostname, port):
                    return optionsForClientTLS(
                        hostname.decode("ascii"), trustRoot=authority
                    )

            return client.Agent(reactor, contextFactory=Policy())

        super().integrationTest(
            hostName,
            expectedAddress,
            addressType,
            serverWrapper=tlsify,
            createAgent=tlsagent,
            scheme=b"https",
        )


class WebClientContextFactoryTests(TestCase):
    """
    Exercise L{twisted.web.client.WebClientContextFactory}, the legacy
    context factory wrapper used by web clients.
    """

    def setUp(self):
        """
        Import C{WebClientContextFactory} and capture (thereby silencing) the
        deprecation warning raised from this method.
        """
        # The import must happen here: flushWarnings below only collects
        # warnings attributed to this very method.
        from twisted.web.client import WebClientContextFactory

        self.warned = self.flushWarnings([WebClientContextFactoryTests.setUp])
        self.webClientContextFactory = WebClientContextFactory

    def test_deprecated(self):
        """
        Importing L{twisted.web.client.WebClientContextFactory} emits exactly
        one L{DeprecationWarning}, because the class is deprecated.
        """
        self.assertEqual(len(self.warned), 1)
        [warning] = self.warned
        self.assertEqual(warning["category"], DeprecationWarning)
        expectedMessage = getDeprecationWarningString(
            self.webClientContextFactory,
            Version("Twisted", 14, 0, 0),
            replacement=BrowserLikePolicyForHTTPS,
        )
        # The emitted message uses ":" where getDeprecationWarningString uses
        # ";"; see https://twistedmatrix.com/trac/ticket/7242
        expectedMessage = expectedMessage.replace(";", ":")
        self.assertEqual(warning["message"], expectedMessage)

    @skipIf(sslPresent, "SSL Present.")
    def test_missingSSL(self):
        """
        Without SSL support, calling C{getContext} raises
        L{NotImplementedError}.
        """
        factory = self.webClientContextFactory()
        self.assertRaises(
            NotImplementedError,
            factory.getContext,
            b"example.com",
            443,
        )

    @skipIf(not sslPresent, "SSL not present, cannot run SSL tests.")
    def test_returnsContext(self):
        """
        With SSL available, C{getContext} produces an L{OpenSSL.SSL.Context}.
        """
        factory = self.webClientContextFactory()
        context = factory.getContext("example.com", 443)
        self.assertIsInstance(context, ssl.SSL.Context)

    @skipIf(not sslPresent, "SSL not present, cannot run SSL tests.")
    def test_setsTrustRootOnContextToDefaultTrustRoot(self):
        """
        The L{CertificateOptions} built by the factory trust the platform's
        default trust roots.
        """
        factory = self.webClientContextFactory()
        options = factory._getCertificateOptions("example.com", 443)
        self.assertIsInstance(options.trustRoot, ssl.OpenSSLDefaultPaths)


class HTTPConnectionPoolRetryTests(TestCase, FakeReactorAndConnectMixin):
    """
    L{client.HTTPConnectionPool}, by using
    L{client._RetryingHTTP11ClientProtocol}, supports retrying requests done
    against previously cached connections.
    """

    def _shouldRetryConnection(self):
        """
        Create a L{client._RetryingHTTP11ClientProtocol} for tests that only
        call its C{_shouldRetry} method.

        The wrapped protocol is L{None} and the second argument is an unused
        pool. The tests using this helper only exercise C{_shouldRetry}, which
        evidently needs neither of them.

        @return: a new L{client._RetryingHTTP11ClientProtocol}.
        """
        pool = client.HTTPConnectionPool(None)
        return client._RetryingHTTP11ClientProtocol(None, pool)

    def test_onlyRetryIdempotentMethods(self):
        """
        Only GET, HEAD, OPTIONS, TRACE, DELETE methods cause a retry.
        """
        connection = self._shouldRetryConnection()
        self.assertTrue(connection._shouldRetry(b"GET", RequestNotSent(), None))
        self.assertTrue(connection._shouldRetry(b"HEAD", RequestNotSent(), None))
        self.assertTrue(connection._shouldRetry(b"OPTIONS", RequestNotSent(), None))
        self.assertTrue(connection._shouldRetry(b"TRACE", RequestNotSent(), None))
        self.assertTrue(connection._shouldRetry(b"DELETE", RequestNotSent(), None))
        self.assertFalse(connection._shouldRetry(b"POST", RequestNotSent(), None))
        self.assertFalse(connection._shouldRetry(b"MYMETHOD", RequestNotSent(), None))
        # PUT is idempotent too, but retrying it requires support for
        # resettable body producers. That will be covered by a different
        # ticket, at which point this test should also assert:
        # self.assertTrue(connection._shouldRetry(b"PUT", RequestNotSent(), None))

    def test_onlyRetryIfNoResponseReceived(self):
        """
        Only L{RequestNotSent}, L{RequestTransmissionFailed} and
        L{ResponseNeverReceived} exceptions cause a retry.
        """
        connection = self._shouldRetryConnection()
        self.assertTrue(connection._shouldRetry(b"GET", RequestNotSent(), None))
        self.assertTrue(
            connection._shouldRetry(b"GET", RequestTransmissionFailed([]), None)
        )
        self.assertTrue(
            connection._shouldRetry(b"GET", ResponseNeverReceived([]), None)
        )
        self.assertFalse(connection._shouldRetry(b"GET", ResponseFailed([]), None))
        self.assertFalse(
            connection._shouldRetry(b"GET", ConnectionRefusedError(), None)
        )

    def test_dontRetryIfFailedDueToCancel(self):
        """
        If a request failed due to the operation being cancelled,
        C{_shouldRetry} returns C{False} to indicate the request should not be
        retried.
        """
        connection = self._shouldRetryConnection()
        exception = ResponseNeverReceived([Failure(defer.CancelledError())])
        self.assertFalse(connection._shouldRetry(b"GET", exception, None))

    def test_retryIfFailedDueToNonCancelException(self):
        """
        If a request failed with L{ResponseNeverReceived} due to some
        arbitrary exception, C{_shouldRetry} returns C{True} to indicate the
        request should be retried.
        """
        connection = self._shouldRetryConnection()
        self.assertTrue(
            connection._shouldRetry(
                b"GET", ResponseNeverReceived([Failure(Exception())]), None
            )
        )

    def test_wrappedOnPersistentReturned(self):
        """
        If L{client.HTTPConnectionPool.getConnection} returns a previously
        cached connection, it will get wrapped in a
        L{client._RetryingHTTP11ClientProtocol}.
        """
        pool = client.HTTPConnectionPool(Clock())

        # Add a connection to the cache:
        protocol = StubHTTPProtocol()
        protocol.makeConnection(StringTransport())
        pool._putConnection(123, protocol)

        # Retrieve it, it should come back wrapped in a
        # _RetryingHTTP11ClientProtocol:
        d = pool.getConnection(123, DummyEndpoint())

        def gotConnection(connection):
            self.assertIsInstance(connection, client._RetryingHTTP11ClientProtocol)
            self.assertIdentical(connection._clientProtocol, protocol)

        return d.addCallback(gotConnection)

    def test_notWrappedOnNewReturned(self):
        """
        If L{client.HTTPConnectionPool.getConnection} returns a new
        connection, it will be returned as is.
        """
        pool = client.HTTPConnectionPool(None)
        d = pool.getConnection(123, DummyEndpoint())

        def gotConnection(connection):
            # Don't want to use isinstance since potentially the wrapper might
            # subclass it at some point:
            self.assertIdentical(connection.__class__, HTTP11ClientProtocol)

        return d.addCallback(gotConnection)

    def retryAttempt(self, willWeRetry):
        """
        Fail a first request, possibly retrying depending on argument.

        A L{client._RetryingHTTP11ClientProtocol} wraps a L{StubHTTPProtocol},
        and its connection factory creates further L{StubHTTPProtocol}s. Its
        C{_shouldRetry} is replaced with a stub that checks its arguments and
        returns C{willWeRetry}. The first request is then failed with
        L{RequestNotSent}.

        @param willWeRetry: the value the stubbed C{_shouldRetry} returns.

        @return: a 2-tuple of the L{Deferred} returned by C{request} and the
            L{list} of every L{StubHTTPProtocol} created so far.
        """
        protocols = []

        def newProtocol():
            protocol = StubHTTPProtocol()
            protocols.append(protocol)
            return defer.succeed(protocol)

        # Only the identity of the body producer is checked, so any object
        # will do.
        bodyProducer = object()
        request = client.Request(
            b"FOO", b"/", client.Headers(), bodyProducer, persistent=True
        )
        newProtocol()
        protocol = protocols[0]
        retrier = client._RetryingHTTP11ClientProtocol(protocol, newProtocol)

        def _shouldRetry(m, e, bp):
            self.assertEqual(m, b"FOO")
            self.assertIdentical(bp, bodyProducer)
            self.assertIsInstance(e, (RequestNotSent, ResponseNeverReceived))
            return willWeRetry

        retrier._shouldRetry = _shouldRetry

        d = retrier.request(request)

        # So far, one request made:
        self.assertEqual(len(protocols), 1)
        self.assertEqual(len(protocols[0].requests), 1)

        # Fail the first request:
        protocol.requests[0][1].errback(RequestNotSent())
        return d, protocols

    def test_retryIfShouldRetryReturnsTrue(self):
        """
        L{client._RetryingHTTP11ClientProtocol} retries when
        L{client._RetryingHTTP11ClientProtocol._shouldRetry} returns C{True}.
        """
        d, protocols = self.retryAttempt(True)
        # We retried!
        self.assertEqual(len(protocols), 2)
        response = object()
        protocols[1].requests[0][1].callback(response)
        return d.addCallback(self.assertIdentical, response)

    def test_dontRetryIfShouldRetryReturnsFalse(self):
        """
        L{client._RetryingHTTP11ClientProtocol} does not retry when
        L{client._RetryingHTTP11ClientProtocol._shouldRetry} returns C{False}.
        """
        d, protocols = self.retryAttempt(False)
        # We did not retry:
        self.assertEqual(len(protocols), 1)
        return self.assertFailure(d, RequestNotSent)

    def test_onlyRetryWithoutBody(self):
        """
        L{_RetryingHTTP11ClientProtocol} only retries queries that don't have
        a body.

        This is an implementation restriction; if the restriction is fixed,
        this test should be removed and PUT added to list of methods that
        support retries.
        """
        connection = self._shouldRetryConnection()
        self.assertTrue(connection._shouldRetry(b"GET", RequestNotSent(), None))
        self.assertFalse(connection._shouldRetry(b"GET", RequestNotSent(), object()))

    def test_onlyRetryOnce(self):
        """
        If a L{client._RetryingHTTP11ClientProtocol} fails more than once on
        an idempotent query before a response is received, it will not retry.
        """
        d, protocols = self.retryAttempt(True)
        self.assertEqual(len(protocols), 2)
        # Fail the second request too:
        protocols[1].requests[0][1].errback(ResponseNeverReceived([]))
        # We didn't retry again:
        self.assertEqual(len(protocols), 2)
        return self.assertFailure(d, ResponseNeverReceived)

    def test_dontRetryIfRetryAutomaticallyFalse(self):
        """
        If L{HTTPConnectionPool.retryAutomatically} is set to C{False}, don't
        wrap connections with retrying logic.
        """
        pool = client.HTTPConnectionPool(Clock())
        pool.retryAutomatically = False

        # Add a connection to the cache:
        protocol = StubHTTPProtocol()
        protocol.makeConnection(StringTransport())
        pool._putConnection(123, protocol)

        # Retrieve it, it should come back unwrapped:
        d = pool.getConnection(123, DummyEndpoint())

        def gotConnection(connection):
            self.assertIdentical(connection, protocol)

        return d.addCallback(gotConnection)

    def test_retryWithNewConnection(self):
        """
        L{client.HTTPConnectionPool} creates
        L{client._RetryingHTTP11ClientProtocol} with a new connection factory
        method that creates a new connection using the same key and endpoint
        as the wrapped connection.
        """
        pool = client.HTTPConnectionPool(Clock())
        key = 123
        endpoint = DummyEndpoint()
        newConnections = []

        # Override the pool's _newConnection:
        def newConnection(k, e):
            newConnections.append((k, e))

        pool._newConnection = newConnection

        # Add a connection to the cache:
        protocol = StubHTTPProtocol()
        protocol.makeConnection(StringTransport())
        pool._putConnection(key, protocol)

        # Retrieve it, it should come back wrapped in a
        # _RetryingHTTP11ClientProtocol:
        d = pool.getConnection(key, endpoint)

        def gotConnection(connection):
            self.assertIsInstance(connection, client._RetryingHTTP11ClientProtocol)
            self.assertIdentical(connection._clientProtocol, protocol)
            # Verify that the _newConnection method on retrying connection
            # calls _newConnection on the pool:
            self.assertEqual(newConnections, [])
            connection._newConnection()
            self.assertEqual(len(newConnections), 1)
            self.assertEqual(newConnections[0][0], key)
            self.assertIdentical(newConnections[0][1], endpoint)

        return d.addCallback(gotConnection)


class CookieTestsMixin:
    """
    Mixin providing a helper to populate a L{CookieJar} in cookie tests.
    """

    def addCookies(self, cookieJar, uri, cookies):
        """
        Store cookies in C{cookieJar} as though they had been set by a
        C{200 OK} response to a request for C{uri}.

        @param cookieJar: the L{CookieJar} to populate.
        @param uri: the request URI the cookies are associated with.
        @param cookies: a L{list} of raw I{Set-Cookie} header values.

        @return: a 2-tuple of the fake urllib2 request and the fake urllib2
            response that were handed to C{cookieJar.extract_cookies}.
        """
        setCookieHeaders = client.Headers({b"Set-Cookie": cookies})
        fakeResponse = client._FakeUrllib2Response(
            client.Response((b"HTTP", 1, 1), 200, b"OK", setCookieHeaders, None)
        )
        fakeRequest = client._FakeUrllib2Request(uri)
        cookieJar.extract_cookies(fakeResponse, fakeRequest)
        return fakeRequest, fakeResponse


class CookieJarTests(TestCase, CookieTestsMixin):
    """
    Tests for how L{CookieJar} instances interact with
    L{twisted.web.client._FakeUrllib2Response} and
    L{twisted.web.client._FakeUrllib2Request}.
    """

    def makeCookieJar(self):
        """
        Build a L{CookieJar} holding the sample cookies C{foo} and C{bar}.

        @return: a 2-tuple of the L{CookieJar} and the C{(request, response)}
            pair returned by C{addCookies}.
        """
        jar = CookieJar()
        setCookieValues = [
            b"foo=1; cow=moo; Path=/foo; Comment=hello",
            b"bar=2; Comment=goodbye",
        ]
        requestAndResponse = self.addCookies(
            jar, b"http://example.com:1234/foo?bar", setCookieValues
        )
        return jar, requestAndResponse

    def test_extractCookies(self):
        """
        L{CookieJar.extract_cookies} pulls cookie information out of fake
        urllib2 response instances.
        """
        jar, _ = self.makeCookieJar()
        byName = {}
        for storedCookie in jar:
            byName[storedCookie.name] = storedCookie

        fooCookie = byName["foo"]
        self.assertEqual(fooCookie.version, 0)
        self.assertEqual(fooCookie.name, "foo")
        self.assertEqual(fooCookie.value, "1")
        self.assertEqual(fooCookie.path, "/foo")
        self.assertEqual(fooCookie.comment, "hello")
        self.assertEqual(fooCookie.get_nonstandard_attr("cow"), "moo")

        barCookie = byName["bar"]
        self.assertEqual(barCookie.version, 0)
        self.assertEqual(barCookie.name, "bar")
        self.assertEqual(barCookie.value, "2")
        self.assertEqual(barCookie.path, "/")
        self.assertEqual(barCookie.comment, "goodbye")
        self.assertIdentical(barCookie.get_nonstandard_attr("cow"), None)

    def test_sendCookie(self):
        """
        L{CookieJar.add_cookie_header} sets a I{Cookie} header on a fake
        urllib2 request instance that previously had none.
        """
        jar, (request, _) = self.makeCookieJar()

        self.assertIdentical(request.get_header("Cookie", None), None)

        jar.add_cookie_header(request)
        self.assertEqual(request.get_header("Cookie", None), "foo=1; bar=2")


class CookieAgentTests(
    TestCase, CookieTestsMixin, FakeReactorAndConnectMixin, AgentTestsMixin
):
    """
    Tests for L{twisted.web.client.CookieAgent}.
    """

    def makeAgent(self):
        """
        @return: a new L{twisted.web.client.CookieAgent}
        """
        wrappedAgent = self.buildAgentForWrapperTest(self.reactor)
        return client.CookieAgent(wrappedAgent, CookieJar())

    def setUp(self):
        """
        Create the fake reactor shared by the tests.
        """
        self.reactor = self.createReactor()

    def _cookieHeaderSent(self, uri, cookie):
        """
        Seed a fresh L{CookieJar} with one cookie for C{uri}, issue a C{GET}
        for C{uri} through a L{client.CookieAgent} using that jar, and report
        the I{Cookie} header of the request that was sent.

        @param uri: the request URI.
        @param cookie: a raw I{Set-Cookie} header value to seed the jar with.

        @return: the raw C{cookie} header values of the outgoing request, or
            L{None} if it has none.
        """
        cookieJar = CookieJar()
        self.addCookies(cookieJar, uri, [cookie])
        self.assertEqual(len(list(cookieJar)), 1)

        agent = self.buildAgentForWrapperTest(self.reactor)
        client.CookieAgent(agent, cookieJar).request(b"GET", uri)

        sentRequest, _ = self.protocol.requests.pop()
        return sentRequest.headers.getRawHeaders(b"cookie")

    def test_emptyCookieJarRequest(self):
        """
        When the cookie jar holds nothing for the requested URI,
        L{CookieAgent.request} adds no C{'Cookie'} header to the L{Request}.
        Cookies set by the response are then stored in the jar.
        """
        cookieJar = CookieJar()
        self.assertEqual(list(cookieJar), [])

        agent = self.buildAgentForWrapperTest(self.reactor)
        cookieAgent = client.CookieAgent(agent, cookieJar)
        d = cookieAgent.request(b"GET", b"http://example.com:1234/foo?bar")

        def _checkCookie(ignored):
            stored = list(cookieJar)
            self.assertEqual(len(stored), 1)
            self.assertEqual(stored[0].name, "foo")
            self.assertEqual(stored[0].value, "1")

        d.addCallback(_checkCookie)

        sentRequest, pendingResult = self.protocol.requests.pop()
        self.assertIdentical(sentRequest.headers.getRawHeaders(b"cookie"), None)

        responseHeaders = client.Headers({b"Set-Cookie": [b"foo=1"]})
        pendingResult.callback(
            client.Response((b"HTTP", 1, 1), 200, b"OK", responseHeaders, None)
        )

        return d

    def test_requestWithCookie(self):
        """
        When the cookie jar holds a cookie matching the request URI,
        L{CookieAgent.request} adds a C{'Cookie'} header carrying it to the
        L{Request}.
        """
        cookie = b"foo=1"
        sent = self._cookieHeaderSent(b"http://example.com:1234/foo?bar", cookie)
        self.assertEqual(sent, [cookie])

    @skipIf(not sslPresent, "SSL not present, cannot run SSL tests.")
    def test_secureCookie(self):
        """
        L{CookieAgent} handles secure cookies, i.e. cookies meant to be sent
        only over HTTPS, by sending them on HTTPS requests.
        """
        sent = self._cookieHeaderSent(
            b"https://example.com:1234/foo?bar", b"foo=1;secure"
        )
        self.assertEqual(sent, [b"foo=1"])

    def test_secureCookieOnInsecureConnection(self):
        """
        A cookie marked as secure is not sent with a request that is not made
        over HTTPS.
        """
        sent = self._cookieHeaderSent(b"http://example.com/foo?bar", b"foo=1;secure")
        self.assertIdentical(None, sent)

    def test_portCookie(self):
        """
        L{CookieAgent} supports cookies that restrict the port number they may
        be sent to, sending them when the request uses that port.
        """
        sent = self._cookieHeaderSent(
            b"http://example.com:1234/foo?bar", b"foo=1;port=1234"
        )
        self.assertEqual(sent, [b"foo=1"])

    def test_portCookieOnWrongPort(self):
        """
        A cookie with a port directive is not added to the L{CookieJar} when
        the URI that set it is on a different port.
        """
        cookieJar = CookieJar()
        self.addCookies(cookieJar, b"http://example.com:4567/foo?bar", [b"foo=1;port=1234"])
        self.assertEqual(len(list(cookieJar)), 0)


class Decoder1(proxyForInterface(IResponse)):  # type: ignore[misc]
    """
    A test decoder to be used by L{client.ContentDecoderAgent} tests.

    It adds no behavior of its own. It is a C{proxyForInterface(IResponse)}
    subclass, so (presumably) L{IResponse} attributes are forwarded to the
    wrapped response. Tests identify which decoder was applied by checking
    the type of the result.
    """


class Decoder2(Decoder1):
    """
    A test decoder to be used by L{client.ContentDecoderAgent} tests.

    It behaves exactly like L{Decoder1}. It exists as a distinct type so
    tests can tell the two decoders apart with C{assertIsInstance}.
    """


class ContentDecoderAgentTests(TestCase, FakeReactorAndConnectMixin, AgentTestsMixin):
    """
    Tests for L{client.ContentDecoderAgent}.
    """

    def makeAgent(self):
        """
        @return: a new L{twisted.web.client.ContentDecoderAgent} with no
            decoders, wrapping C{self.agent}.
        """
        return client.ContentDecoderAgent(self.agent, [])

    def setUp(self):
        """
        Create an L{Agent} wrapped around a fake reactor.
        """
        self.reactor = self.createReactor()
        self.agent = self.buildAgentForWrapperTest(self.reactor)

    def _decodingAgent(self):
        """
        @return: a L{client.ContentDecoderAgent} wrapping C{self.agent} that
            knows the C{decoder1} (L{Decoder1}) and C{decoder2} (L{Decoder2})
            encodings, in that order.
        """
        return client.ContentDecoderAgent(
            self.agent, [(b"decoder1", Decoder1), (b"decoder2", Decoder2)]
        )

    def test_acceptHeaders(self):
        """
        L{client.ContentDecoderAgent} sets the I{Accept-Encoding} header to the
        names of the available decoder objects.
        """
        agent = self._decodingAgent()

        agent.request(b"GET", b"http://example.com/foo")

        protocol = self.protocol

        self.assertEqual(len(protocol.requests), 1)
        req, res = protocol.requests.pop()
        self.assertEqual(
            req.headers.getRawHeaders(b"accept-encoding"), [b"decoder1,decoder2"]
        )

    def test_existingHeaders(self):
        """
        If there are existing I{Accept-Encoding} fields,
        L{client.ContentDecoderAgent} creates a new field for the decoders it
        knows about.
        """
        headers = http_headers.Headers(
            {b"foo": [b"bar"], b"accept-encoding": [b"fizz"]}
        )
        agent = self._decodingAgent()
        agent.request(b"GET", b"http://example.com/foo", headers=headers)

        protocol = self.protocol

        self.assertEqual(len(protocol.requests), 1)
        req, res = protocol.requests.pop()
        # sorted() already returns a list, so no extra list() is needed.
        self.assertEqual(
            sorted(req.headers.getAllRawHeaders()),
            [
                (b"Accept-Encoding", [b"fizz", b"decoder1,decoder2"]),
                (b"Foo", [b"bar"]),
                (b"Host", [b"example.com"]),
            ],
        )

    def test_plainEncodingResponse(self):
        """
        If the response is not encoded despite the request I{Accept-Encoding}
        headers, L{client.ContentDecoderAgent} simply forwards the response.
        """
        agent = self._decodingAgent()
        deferred = agent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        response = Response((b"HTTP", 1, 1), 200, b"OK", http_headers.Headers(), None)
        res.callback(response)

        return deferred.addCallback(self.assertIdentical, response)

    def test_unsupportedEncoding(self):
        """
        If an encoding unknown to the L{client.ContentDecoderAgent} is found,
        the response is unchanged.
        """
        agent = self._decodingAgent()
        deferred = agent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers(
            {b"foo": [b"bar"], b"content-encoding": [b"fizz"]}
        )
        response = Response((b"HTTP", 1, 1), 200, b"OK", headers, None)
        res.callback(response)

        return deferred.addCallback(self.assertIdentical, response)

    def test_unknownEncoding(self):
        """
        When L{client.ContentDecoderAgent} encounters a decoder it doesn't know
        about, it stops decoding even if another encoding is known afterwards.
        """
        agent = self._decodingAgent()
        deferred = agent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers(
            {b"foo": [b"bar"], b"content-encoding": [b"decoder1,fizz,decoder2"]}
        )
        response = Response((b"HTTP", 1, 1), 200, b"OK", headers, None)
        res.callback(response)

        def check(result):
            # Encodings are undone last-applied first: decoder2 is handled,
            # then "fizz" is unknown, so decoding stops there.
            self.assertNotIdentical(response, result)
            self.assertIsInstance(result, Decoder2)
            self.assertEqual(
                [b"decoder1,fizz"], result.headers.getRawHeaders(b"content-encoding")
            )

        return deferred.addCallback(check)


class SimpleAgentProtocol(Protocol):
    """
    A L{Protocol} that collects the body an L{client.Agent} response
    delivers to it.

    @ivar made: L{Deferred} fired with L{None} from C{connectionMade}.

    @ivar finished: L{Deferred} fired with L{None} from C{connectionLost},
        whatever the reason.

    @ivar received: L{list} of the chunks passed to C{dataReceived}, in
        arrival order.
    """

    def __init__(self):
        self.made, self.finished = Deferred(), Deferred()
        self.received = []

    def connectionMade(self):
        made = self.made
        made.callback(None)

    def connectionLost(self, reason):
        finished = self.finished
        finished.callback(None)

    def dataReceived(self, data):
        received = self.received
        received.append(data)


class ContentDecoderAgentWithGzipTests(TestCase, FakeReactorAndConnectMixin):
    """
    Tests for L{client.ContentDecoderAgent} configured with
    L{client.GzipDecoder} as the decoder for the C{gzip} content encoding.
    """

    def setUp(self):
        """
        Create an L{Agent} wrapped around a fake reactor.
        """
        self.reactor = self.createReactor()
        agent = self.buildAgentForWrapperTest(self.reactor)
        self.agent = client.ContentDecoderAgent(agent, [(b"gzip", client.GzipDecoder)])

    def test_gzipEncodingResponse(self):
        """
        If the response has a C{gzip} I{Content-Encoding} header,
        L{GzipDecoder} wraps the response to return uncompressed data to the
        user.
        """
        deferred = self.agent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers(
            {b"foo": [b"bar"], b"content-encoding": [b"gzip"]}
        )
        transport = StringTransport()
        response = Response((b"HTTP", 1, 1), 200, b"OK", headers, transport)
        response.length = 12
        res.callback(response)

        # 16 + MAX_WBITS makes zlib emit a gzip container rather than raw
        # zlib data.
        compressor = zlib.compressobj(2, zlib.DEFLATED, 16 + zlib.MAX_WBITS)
        data = (
            compressor.compress(b"x" * 6)
            + compressor.compress(b"y" * 4)
            + compressor.flush()
        )

        def checkResponse(result):
            self.assertNotIdentical(result, response)
            self.assertEqual(result.version, (b"HTTP", 1, 1))
            self.assertEqual(result.code, 200)
            self.assertEqual(result.phrase, b"OK")
            # The Content-Encoding header is consumed by the decoder and the
            # decoded length is no longer known up front.
            self.assertEqual(
                list(result.headers.getAllRawHeaders()), [(b"Foo", [b"bar"])]
            )
            self.assertEqual(result.length, UNKNOWN_LENGTH)
            self.assertRaises(AttributeError, getattr, result, "unknown")

            response._bodyDataReceived(data[:5])
            response._bodyDataReceived(data[5:])
            response._bodyDataFinished()

            protocol = SimpleAgentProtocol()
            result.deliverBody(protocol)

            self.assertEqual(protocol.received, [b"x" * 6 + b"y" * 4])
            return defer.gatherResults([protocol.made, protocol.finished])

        deferred.addCallback(checkResponse)

        return deferred

    def test_brokenContent(self):
        """
        If the data received by the L{GzipDecoder} isn't valid gzip-compressed
        data, the call to C{deliverBody} fails with a C{zlib.error}.
        """
        deferred = self.agent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers(
            {b"foo": [b"bar"], b"content-encoding": [b"gzip"]}
        )
        transport = StringTransport()
        response = Response((b"HTTP", 1, 1), 200, b"OK", headers, transport)
        response.length = 12
        res.callback(response)

        data = b"not gzipped content"

        def checkResponse(result):
            response._bodyDataReceived(data)

            result.deliverBody(Protocol())

        deferred.addCallback(checkResponse)
        self.assertFailure(deferred, client.ResponseFailed)

        def checkFailure(error):
            error.reasons[0].trap(zlib.error)
            self.assertIsInstance(error.response, Response)

        return deferred.addCallback(checkFailure)

    def test_flushData(self):
        """
        When the connection with the server is lost, the gzip protocol calls
        C{flush} on the zlib decompressor object to get uncompressed data which
        may have been buffered.
        """

        class decompressobj:
            def __init__(self, wbits):
                pass

            def decompress(self, data):
                return b"x"

            def flush(self):
                return b"y"

        # trial restores the original attribute when the test finishes.
        self.patch(zlib, "decompressobj", decompressobj)

        deferred = self.agent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers({b"content-encoding": [b"gzip"]})
        transport = StringTransport()
        response = Response((b"HTTP", 1, 1), 200, b"OK", headers, transport)
        res.callback(response)

        def checkResponse(result):
            response._bodyDataReceived(b"data")
            response._bodyDataFinished()

            protocol = SimpleAgentProtocol()
            result.deliverBody(protocol)

            # b"x" comes from decompress(), b"y" from the final flush().
            self.assertEqual(protocol.received, [b"x", b"y"])
            return defer.gatherResults([protocol.made, protocol.finished])

        deferred.addCallback(checkResponse)

        return deferred

    def test_flushError(self):
        """
        If the C{flush} call in C{connectionLost} fails, the C{zlib.error}
        exception is caught and turned into a L{ResponseFailed}.
        """

        class decompressobj:
            def __init__(self, wbits):
                pass

            def decompress(self, data):
                return b"x"

            def flush(self):
                raise zlib.error()

        # trial restores the original attribute when the test finishes.
        self.patch(zlib, "decompressobj", decompressobj)

        deferred = self.agent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers({b"content-encoding": [b"gzip"]})
        transport = StringTransport()
        response = Response((b"HTTP", 1, 1), 200, b"OK", headers, transport)
        res.callback(response)

        def checkResponse(result):
            response._bodyDataReceived(b"data")
            response._bodyDataFinished()

            # The fake flush() raises zlib.error. That surfaces as a
            # ResponseFailed raised out of deliverBody and fails this
            # callback, which is what the assertFailure below relies on.
            # Nothing placed after this call would ever run.
            result.deliverBody(SimpleAgentProtocol())

        deferred.addCallback(checkResponse)

        self.assertFailure(deferred, client.ResponseFailed)

        def checkFailure(error):
            error.reasons[1].trap(zlib.error)
            self.assertIsInstance(error.response, Response)

        return deferred.addCallback(checkFailure)


class ProxyAgentTests(TestCase, FakeReactorAndConnectMixin, AgentTestsMixin):
    """
    Tests for L{client.ProxyAgent}.
    """

    def makeAgent(self):
        """
        @return: a new L{twisted.web.client.ProxyAgent}
        """
        proxyEndpoint = TCP4ClientEndpoint(self.reactor, "127.0.0.1", 1234)
        return client.ProxyAgent(proxyEndpoint, self.reactor)

    def setUp(self):
        """
        Create a L{client.ProxyAgent} for a proxy at C{bar:5678} on a fake
        reactor, with its proxy endpoint wrapped in a C{StubEndpoint}.
        """
        self.reactor = self.createReactor()
        proxyEndpoint = TCP4ClientEndpoint(self.reactor, "bar", 5678)
        self.agent = client.ProxyAgent(proxyEndpoint, self.reactor)
        self.agent._proxyEndpoint = self.StubEndpoint(
            self.agent._proxyEndpoint, self
        )

    def test_nonBytesMethod(self):
        """
        Passing a non-L{bytes} C{method} to L{ProxyAgent.request} raises
        L{TypeError}.
        """
        self.assertRaises(TypeError, self.agent.request, "GET", b"http://foo.example/")

    def test_proxyRequest(self):
        """
        Calling C{request} on a L{client.ProxyAgent} sends an HTTP request to
        the proxy whose path is the full target URI.
        """
        uri = b"http://example.com:1234/foo?bar"
        headers = http_headers.Headers({b"foo": [b"bar"]})
        # Only the identity of the body is checked, so any object will do.
        body = object()
        self.agent.request(b"GET", uri, headers, body)

        host, port, factory = self.reactor.tcpClients.pop()[:3]
        self.assertEqual(host, "bar")
        self.assertEqual(port, 5678)
        self.assertIsInstance(factory._wrappedFactory, client._HTTP11ClientFactory)

        # Exactly one request should have gone out through the proxy.
        issued = self.protocol.requests
        self.assertEqual(len(issued), 1)
        req, _ = issued.pop()

        self.assertIsInstance(req, Request)
        self.assertEqual(req.method, b"GET")
        self.assertEqual(req.uri, uri)
        expectedHeaders = http_headers.Headers(
            {b"foo": [b"bar"], b"host": [b"example.com:1234"]}
        )
        self.assertEqual(req.headers, expectedHeaders)
        self.assertIdentical(req.bodyProducer, body)

    def test_nonPersistent(self):
        """
        By default, C{ProxyAgent} connections are not persistent.
        """
        self.assertEqual(self.agent._pool.persistent, False)

    def test_connectUsesConnectionPool(self):
        """
        The C{ProxyAgent} obtains connections from its pool's
        C{getConnection}, passing the endpoint it was built with and the key
        C{("http-proxy", endpoint)}.
        """
        endpoint = DummyEndpoint()
        testCase = self

        class DummyPool:
            connected = False
            persistent = False

            def getConnection(self, key, ep):
                self.connected = True
                testCase.assertIdentical(ep, endpoint)
                # The key reflects only the proxy's address, which is where
                # we actually connect, not the final destination.
                testCase.assertEqual(key, ("http-proxy", endpoint))
                return defer.succeed(StubHTTPProtocol())

        pool = DummyPool()
        agent = client.ProxyAgent(endpoint, self.reactor, pool=pool)
        self.assertIdentical(pool, agent._pool)

        agent.request(b"GET", b"http://foo/")
        self.assertEqual(agent._pool.connected, True)


# Lowercase names of authentication- and cookie-related HTTP headers.
# NOTE(review): presumably used by the redirect tests that follow to check
# which request headers must not be forwarded on cross-origin redirects;
# confirm at the use sites.
SENSITIVE_HEADERS = [
    b"authorization",
    b"cookie",
    b"cookie2",
    b"proxy-authorization",
    b"www-authenticate",
]


class _RedirectAgentTestsMixin(testMixinClass):
    """
    Test cases mixin for L{RedirectAgentTests} and
    L{BrowserLikeRedirectAgentTests}.

    Concrete test cases provide C{agent} (the redirect agent under test),
    C{reactor}, C{protocol} (the stub protocol that records issued requests)
    and C{buildAgentForWrapperTest}.
    """

    agent: IAgent
    reactor: MemoryReactorClock
    protocol: StubHTTPProtocol

    def test_noRedirect(self):
        """
        L{client.RedirectAgent} behaves like L{client.Agent} if the response
        doesn't contain a redirect.
        """
        deferred = self.agent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers()
        response = Response((b"HTTP", 1, 1), 200, b"OK", headers, None)
        res.callback(response)

        self.assertEqual(0, len(self.protocol.requests))
        result = self.successResultOf(deferred)
        self.assertIdentical(response, result)
        self.assertIdentical(result.previousResponse, None)

    def _testRedirectDefault(
        self,
        code: int,
        crossScheme: bool = False,
        crossDomain: bool = False,
        crossPort: bool = False,
        requestHeaders: Optional[Headers] = None,
    ) -> Request:
        """
        When getting a redirect, L{client.RedirectAgent} follows the URL
        specified in the L{Location} header field and makes a new request.

        @param code: HTTP status code.

        @param crossScheme: If C{True}, redirect to the other scheme
            (C{http} <-> C{https}) on that scheme's default port.  The test
            is skipped when TLS support is unavailable.

        @param crossDomain: If C{True}, redirect to C{example.net} instead of
            staying on C{example.com}.

        @param crossPort: If C{True}, redirect to the explicit port 8443.

        @param requestHeaders: Headers to send with the initial request, if
            any.

        @return: The redirected request issued by the agent.
        """
        startDomain = b"example.com"
        startScheme = b"https" if ssl is not None else b"http"
        startPort = 80 if startScheme == b"http" else 443
        self.agent.request(
            b"GET", startScheme + b"://" + startDomain + b"/foo", headers=requestHeaders
        )

        host, port = self.reactor.tcpClients.pop()[:2]
        self.assertEqual(EXAMPLE_COM_IP, host)
        self.assertEqual(startPort, port)

        req, res = self.protocol.requests.pop()

        # If possible (i.e.: TLS support is present), run the test with a
        # cross-scheme redirect to verify that the scheme is honored; if not,
        # let's just make sure it works at all.

        targetScheme = startScheme
        targetPort = startPort

        if crossScheme:
            if ssl is None:
                raise SkipTest(
                    "Cross-scheme redirects can't be tested without TLS support."
                )
            targetScheme = b"https" if startScheme == b"http" else b"http"
            targetPort = 443 if startPort == 80 else 80

        portSyntax = b""
        if crossPort:
            targetPort = 8443
            portSyntax = b":8443"
        targetDomain = b"example.net" if crossDomain else startDomain
        locationValue = targetScheme + b"://" + targetDomain + portSyntax + b"/bar"
        headers = http_headers.Headers({b"location": [locationValue]})
        response = Response((b"HTTP", 1, 1), code, b"OK", headers, None)
        res.callback(response)

        req2, res2 = self.protocol.requests.pop()
        self.assertEqual(b"GET", req2.method)
        self.assertEqual(b"/bar", req2.uri)

        host, port = self.reactor.tcpClients.pop()[:2]
        self.assertEqual(EXAMPLE_NET_IP if crossDomain else EXAMPLE_COM_IP, host)
        self.assertEqual(targetPort, port)
        return req2

    def test_redirect301(self):
        """
        L{client.RedirectAgent} follows redirects on status code 301.
        """
        self._testRedirectDefault(301)

    def test_redirect301Scheme(self):
        """
        L{client.RedirectAgent} follows cross-scheme redirects.
        """
        self._testRedirectDefault(
            301,
            crossScheme=True,
        )

    def test_redirect302(self):
        """
        L{client.RedirectAgent} follows redirects on status code 302.
        """
        self._testRedirectDefault(302)

    def test_redirect307(self):
        """
        L{client.RedirectAgent} follows redirects on status code 307.
        """
        self._testRedirectDefault(307)

    def test_redirect308(self):
        """
        L{client.RedirectAgent} follows redirects on status code 308.
        """
        self._testRedirectDefault(308)

    def _sensitiveHeadersTest(
        self, expectedHostHeader: bytes = b"example.com", **crossKwargs: bool
    ) -> None:
        """
        L{client.RedirectAgent} scrubs sensitive headers when redirecting
        between differing origins.

        A same-origin redirect is checked first and must carry every request
        header along.  A second redirect, made cross-origin by
        C{crossKwargs}, must keep only the non-sensitive headers.

        @param expectedHostHeader: The I{Host} header value expected on the
            cross-origin redirected request.

        @param crossKwargs: Keyword arguments (C{crossScheme},
            C{crossDomain}, C{crossPort}) forwarded to
            C{_testRedirectDefault} for the cross-origin redirect.
        """
        sensitiveHeaderValues = {
            b"authorization": [b"sensitive-authnz"],
            b"cookie": [b"sensitive-cookie-data"],
            b"cookie2": [b"sensitive-cookie2-data"],
            b"proxy-authorization": [b"sensitive-proxy-auth"],
            b"wWw-auThentiCate": [b"sensitive-authn"],
            # Not sensitive by default: the agent under test must be
            # configured (via sensitiveHeaderNames) to scrub it.
            b"x-custom-sensitive": [b"sensitive-custom"],
        }
        otherHeaderValues = {b"x-random-header": [b"x-random-value"]}
        allHeaders = Headers({**sensitiveHeaderValues, **otherHeaderValues})
        redirected = self._testRedirectDefault(301, requestHeaders=allHeaders)

        def normHeaders(headers: Headers) -> dict:
            return {k.lower(): v for (k, v) in headers.getAllRawHeaders()}

        sameOriginHeaders = normHeaders(redirected.headers)
        self.assertEqual(
            sameOriginHeaders,
            {
                b"host": [b"example.com"],
                **normHeaders(allHeaders),
            },
        )

        redirectedElsewhere = self._testRedirectDefault(
            301,
            **crossKwargs,
            requestHeaders=Headers({**sensitiveHeaderValues, **otherHeaderValues}),
        )
        otherOriginHeaders = normHeaders(redirectedElsewhere.headers)
        self.assertEqual(
            otherOriginHeaders,
            {
                b"host": [expectedHostHeader],
                **normHeaders(Headers(otherHeaderValues)),
            },
        )

    def test_crossDomainHeaders(self) -> None:
        """
        L{client.RedirectAgent} scrubs sensitive headers when redirecting
        between differing domains.
        """
        self._sensitiveHeadersTest(crossDomain=True, expectedHostHeader=b"example.net")

    def test_crossPortHeaders(self) -> None:
        """
        L{client.RedirectAgent} scrubs sensitive headers when redirecting
        between differing ports.
        """
        self._sensitiveHeadersTest(
            crossPort=True, expectedHostHeader=b"example.com:8443"
        )

    def test_crossSchemeHeaders(self) -> None:
        """
        L{client.RedirectAgent} scrubs sensitive headers when redirecting
        between differing schemes.
        """
        self._sensitiveHeadersTest(crossScheme=True)

    def _testRedirectToGet(self, code, method):
        """
        L{client.RedirectAgent} changes the method to I{GET} when getting
        a redirect on a non-I{GET} request.

        @param code: HTTP status code.

        @param method: HTTP request method.
        """
        self.agent.request(method, b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers({b"location": [b"http://example.com/bar"]})
        response = Response((b"HTTP", 1, 1), code, b"OK", headers, None)
        res.callback(response)

        req2, res2 = self.protocol.requests.pop()
        self.assertEqual(b"GET", req2.method)
        self.assertEqual(b"/bar", req2.uri)

    def test_redirect303(self):
        """
        L{client.RedirectAgent} changes the method to I{GET} when getting a 303
        redirect on a I{POST} request.
        """
        self._testRedirectToGet(303, b"POST")

    def test_noLocationField(self):
        """
        If no L{Location} header field is found when getting a redirect,
        L{client.RedirectAgent} fails with a L{ResponseFailed} error wrapping a
        L{error.RedirectWithNoLocation} exception.
        """
        deferred = self.agent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers()
        response = Response((b"HTTP", 1, 1), 301, b"OK", headers, None)
        res.callback(response)

        fail = self.failureResultOf(deferred, client.ResponseFailed)
        fail.value.reasons[0].trap(error.RedirectWithNoLocation)
        self.assertEqual(b"http://example.com/foo", fail.value.reasons[0].value.uri)
        self.assertEqual(301, fail.value.response.code)

    def _testPageRedirectFailure(self, code, method):
        """
        When getting a redirect on an unsupported request method,
        L{client.RedirectAgent} fails with a L{ResponseFailed} error wrapping
        a L{error.PageRedirect} exception.

        @param code: HTTP status code.

        @param method: HTTP request method.
        """
        deferred = self.agent.request(method, b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers()
        response = Response((b"HTTP", 1, 1), code, b"OK", headers, None)
        res.callback(response)

        fail = self.failureResultOf(deferred, client.ResponseFailed)
        fail.value.reasons[0].trap(error.PageRedirect)
        self.assertEqual(
            b"http://example.com/foo", fail.value.reasons[0].value.location
        )
        self.assertEqual(code, fail.value.response.code)

    def test_307OnPost(self):
        """
        When getting a 307 redirect on a I{POST} request,
        L{client.RedirectAgent} fails with a L{ResponseFailed} error wrapping
        a L{error.PageRedirect} exception.
        """
        self._testPageRedirectFailure(307, b"POST")

    def test_redirectLimit(self):
        """
        If the limit of redirects specified to L{client.RedirectAgent} is
        reached, the deferred fires with L{ResponseFailed} error wrapping
        an L{error.InfiniteRedirection} exception.
        """
        agent = self.buildAgentForWrapperTest(self.reactor)
        redirectAgent = client.RedirectAgent(agent, 1)

        deferred = redirectAgent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers({b"location": [b"http://example.com/bar"]})
        response = Response((b"HTTP", 1, 1), 302, b"OK", headers, None)
        res.callback(response)

        req2, res2 = self.protocol.requests.pop()

        response2 = Response((b"HTTP", 1, 1), 302, b"OK", headers, None)
        res2.callback(response2)

        fail = self.failureResultOf(deferred, client.ResponseFailed)

        fail.value.reasons[0].trap(error.InfiniteRedirection)
        self.assertEqual(
            b"http://example.com/foo", fail.value.reasons[0].value.location
        )
        self.assertEqual(302, fail.value.response.code)

    def _testRedirectURI(self, uri, location, finalURI):
        """
        When L{client.RedirectAgent} encounters a relative redirect I{URI}, it
        is resolved against the request I{URI} before following the redirect.

        @param uri: Request URI.

        @param location: I{Location} header redirect URI.

        @param finalURI: Expected final URI.
        """
        self.agent.request(b"GET", uri)

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers({b"location": [location]})
        response = Response((b"HTTP", 1, 1), 302, b"OK", headers, None)
        res.callback(response)

        req2, res2 = self.protocol.requests.pop()
        self.assertEqual(b"GET", req2.method)
        self.assertEqual(finalURI, req2.absoluteURI)

    def test_relativeURI(self):
        """
        L{client.RedirectAgent} resolves and follows relative I{URI}s in
        redirects, preserving query strings.
        """
        self._testRedirectURI(
            b"http://example.com/foo/bar", b"baz", b"http://example.com/foo/baz"
        )
        self._testRedirectURI(
            b"http://example.com/foo/bar", b"/baz", b"http://example.com/baz"
        )
        self._testRedirectURI(
            b"http://example.com/foo/bar", b"/baz?a", b"http://example.com/baz?a"
        )

    def test_relativeURIPreserveFragments(self):
        """
        L{client.RedirectAgent} resolves and follows relative I{URI}s in
        redirects, preserving fragments in way that complies with the HTTP 1.1
        bis draft.

        @see: U{https://tools.ietf.org/html/draft-ietf-httpbis-p2-semantics-22#section-7.1.2}
        """
        self._testRedirectURI(
            b"http://example.com/foo/bar#frag",
            b"/baz?a",
            b"http://example.com/baz?a#frag",
        )
        self._testRedirectURI(
            b"http://example.com/foo/bar",
            b"/baz?a#frag2",
            b"http://example.com/baz?a#frag2",
        )

    def test_relativeURISchemeRelative(self):
        """
        L{client.RedirectAgent} resolves and follows scheme relative I{URI}s in
        redirects, replacing the hostname and port when required.
        """
        self._testRedirectURI(
            b"http://example.com/foo/bar", b"//foo.com/baz", b"http://foo.com/baz"
        )
        self._testRedirectURI(
            b"http://example.com/foo/bar", b"//foo.com:81/baz", b"http://foo.com:81/baz"
        )

    def test_responseHistory(self):
        """
        L{Response.previousResponse} references the previous L{Response} from
        a redirect, or L{None} if there was no previous response.
        """
        agent = self.buildAgentForWrapperTest(self.reactor)
        redirectAgent = client.RedirectAgent(agent)

        deferred = redirectAgent.request(b"GET", b"http://example.com/foo")

        redirectReq, redirectRes = self.protocol.requests.pop()

        headers = http_headers.Headers({b"location": [b"http://example.com/bar"]})
        redirectResponse = Response((b"HTTP", 1, 1), 302, b"OK", headers, None)
        redirectRes.callback(redirectResponse)

        req, res = self.protocol.requests.pop()

        response = Response((b"HTTP", 1, 1), 200, b"OK", headers, None)
        res.callback(response)

        finalResponse = self.successResultOf(deferred)
        self.assertIdentical(finalResponse.previousResponse, redirectResponse)
        self.assertIdentical(redirectResponse.previousResponse, None)


class RedirectAgentTests(
    FakeReactorAndConnectMixin,
    _RedirectAgentTestsMixin,
    AgentTestsMixin,
    runtimeTestCase,
):
    """
    Tests for L{client.RedirectAgent}.
    """

    def makeAgent(self):
        """
        Build the redirect agent under test.

        @return: a new L{twisted.web.client.RedirectAgent} wrapping the agent
            returned by C{buildAgentForWrapperTest}, configured to also treat
            C{X-Custom-sensitive} as a sensitive header.
        """
        wrappedAgent = self.buildAgentForWrapperTest(self.reactor)
        extraSensitive = [b"X-Custom-sensitive"]
        return client.RedirectAgent(wrappedAgent, sensitiveHeaderNames=extraSensitive)

    def setUp(self):
        # makeAgent reads self.reactor, so the reactor must exist first.
        self.reactor = self.createReactor()
        self.agent = self.makeAgent()

    def test_301OnPost(self):
        """
        L{client.RedirectAgent} refuses to follow a 301 answer to a I{POST}:
        the request fails with L{ResponseFailed} wrapping
        L{error.PageRedirect}.
        """
        self._testPageRedirectFailure(301, b"POST")

    def test_302OnPost(self):
        """
        L{client.RedirectAgent} refuses to follow a 302 answer to a I{POST}:
        the request fails with L{ResponseFailed} wrapping
        L{error.PageRedirect}.
        """
        self._testPageRedirectFailure(302, b"POST")


class BrowserLikeRedirectAgentTests(
    FakeReactorAndConnectMixin,
    _RedirectAgentTestsMixin,
    AgentTestsMixin,
    runtimeTestCase,
):
    """
    Tests for L{client.BrowserLikeRedirectAgent}.
    """

    def makeAgent(self):
        """
        @return: a new L{twisted.web.client.BrowserLikeRedirectAgent} wrapping
            the agent returned by C{buildAgentForWrapperTest}, configured to
            also treat C{x-Custom-sensitive} as a sensitive header.
        """
        return client.BrowserLikeRedirectAgent(
            self.buildAgentForWrapperTest(self.reactor),
            sensitiveHeaderNames=[b"x-Custom-sensitive"],
        )

    def setUp(self):
        # makeAgent reads self.reactor, so the reactor must exist first.
        self.reactor = self.createReactor()
        self.agent = self.makeAgent()

    def test_redirectToGet301(self):
        """
        L{client.BrowserLikeRedirectAgent} changes the method to I{GET} when
        getting a 301 redirect on a I{POST} request.
        """
        self._testRedirectToGet(301, b"POST")

    def test_redirectToGet302(self):
        """
        L{client.BrowserLikeRedirectAgent} changes the method to I{GET} when
        getting a 302 redirect on a I{POST} request.
        """
        self._testRedirectToGet(302, b"POST")


class AbortableStringTransport(StringTransport):
    """
    A L{StringTransport} that also offers C{abortConnection}, remembering
    in C{aborting} whether an abort was requested.
    """

    def abortConnection(self):
        """
        Test double for C{ITCPTransport.abortConnection}.

        Aborting is treated as a special way of closing, so after flagging
        the abort this defers to C{loseConnection}.
        """
        self.aborting = True
        self.loseConnection()

    # This should be replaced by a common version in #6530.
    aborting = False


class DummyResponse:
    """
    Stand-in L{IResponse} for C{readBody} tests.

    Whatever protocol is handed to C{deliverBody} is stored and immediately
    connected to a transport built by C{transportFactory}.

    @ivar protocol: Set by C{deliverBody} to the protocol it received.

    @ivar transport: The object produced by C{transportFactory}; it becomes
        the transport of L{DummyResponse.protocol}.
    """

    code = 200
    phrase = b"OK"

    def __init__(self, headers=None, transportFactory=AbortableStringTransport):
        """
        @param headers: Response headers; a fresh empty L{Headers} is used
            when L{None}.
        @type headers: L{Headers}

        @param transportFactory: Zero-argument callable producing the
            transport.
        """
        self.headers = Headers() if headers is None else headers
        self.transport = transportFactory()

    def deliverBody(self, protocol):
        """
        Remember C{protocol} and connect it to L{DummyResponse.transport}.
        """
        self.protocol = protocol
        protocol.makeConnection(self.transport)


class AlreadyCompletedDummyResponse(DummyResponse):
    """
    A L{DummyResponse} whose protocol has lost its transport by the time
    C{deliverBody} returns.
    """

    def deliverBody(self, protocol):
        """
        Connect C{protocol} exactly as L{DummyResponse} does, then detach its
        transport to mimic a body that was already fully delivered.
        """
        super().deliverBody(protocol)
        protocol.transport = None


class ReadBodyTests(TestCase):
    """
    Tests for L{client.readBody}.
    """

    def test_success(self):
        """
        L{client.readBody} returns a L{Deferred} which fires with the complete
        body of the L{IResponse} provider passed to it.
        """
        response = DummyResponse()
        d = client.readBody(response)
        response.protocol.dataReceived(b"first")
        response.protocol.dataReceived(b"second")
        response.protocol.connectionLost(Failure(ResponseDone()))
        self.assertEqual(self.successResultOf(d), b"firstsecond")

    def test_cancel(self):
        """
        When cancelling the L{Deferred} returned by L{client.readBody}, the
        connection to the server will be aborted.
        """
        response = DummyResponse()
        deferred = client.readBody(response)
        deferred.cancel()
        self.failureResultOf(deferred, defer.CancelledError)
        self.assertTrue(response.transport.aborting)

    def test_withPotentialDataLoss(self):
        """
        If the full body of the L{IResponse} passed to L{client.readBody} is
        not definitely received, the L{Deferred} returned by L{client.readBody}
        fires with a L{Failure} wrapping L{client.PartialDownloadError} with
        the content that was received.
        """
        response = DummyResponse()
        d = client.readBody(response)
        response.protocol.dataReceived(b"first")
        response.protocol.dataReceived(b"second")
        response.protocol.connectionLost(Failure(PotentialDataLoss()))
        failure = self.failureResultOf(d)
        failure.trap(client.PartialDownloadError)
        self.assertEqual(
            {
                "status": failure.value.status,
                "message": failure.value.message,
                "body": failure.value.response,
            },
            {
                "status": b"200",
                "message": b"OK",
                "body": b"firstsecond",
            },
        )

    def test_otherErrors(self):
        """
        If there is an exception other than L{client.PotentialDataLoss} while
        L{client.readBody} is collecting the response body, the L{Deferred}
        returned by L{client.readBody} fires with that exception.
        """
        response = DummyResponse()
        d = client.readBody(response)
        response.protocol.dataReceived(b"first")
        response.protocol.connectionLost(Failure(ConnectionLost("mystery problem")))
        reason = self.failureResultOf(d)
        reason.trap(ConnectionLost)
        self.assertEqual(reason.value.args, ("mystery problem",))

    def test_deprecatedTransport(self):
        """
        Calling L{client.readBody} with a transport that does not implement
        L{twisted.internet.interfaces.ITCPTransport} produces a deprecation
        warning, but no exception when cancelling.
        """
        response = DummyResponse(transportFactory=StringTransport)
        # Simulate a transport without a usable abortConnection method.
        response.transport.abortConnection = None
        d = self.assertWarns(
            DeprecationWarning,
            "Using readBody with a transport that does not have an "
            "abortConnection method",
            __file__,
            lambda: client.readBody(response),
        )
        d.cancel()
        self.failureResultOf(d, defer.CancelledError)

    def test_deprecatedTransportNoWarning(self):
        """
        Calling L{client.readBody} with a response that has already had its
        transport closed (eg. for a very small request) will not trigger a
        deprecation warning.
        """
        response = AlreadyCompletedDummyResponse()
        client.readBody(response)

        emitted = self.flushWarnings()
        self.assertEqual(emitted, [])


@skipIf(not sslPresent, "SSL not present, cannot run SSL tests.")
class HostnameCachingHTTPSPolicyTests(TestCase):
    """
    Tests for L{HostnameCachingHTTPSPolicy}.
    """

    def test_cacheIsUsed(self):
        """
        Verify that the connection creator is added to the
        policy's cache, and that it is reused on subsequent calls
        to creatorForNetLoc.
        """
        trustRoot = CustomOpenSSLTrustRoot()
        wrappedPolicy = BrowserLikePolicyForHTTPS(trustRoot=trustRoot)
        policy = HostnameCachingHTTPSPolicy(wrappedPolicy)
        creator = policy.creatorForNetloc(b"foo", 1589)
        self.assertTrue(trustRoot.called)
        # Reset the flag so the second lookup can prove the wrapped policy
        # (and thus the trust root) was not consulted again.
        trustRoot.called = False
        self.assertEqual(1, len(policy._cache))
        connection = creator.clientConnectionForTLS(None)
        self.assertIs(trustRoot.context, connection.get_context())

        policy.creatorForNetloc(b"foo", 1589)
        self.assertFalse(trustRoot.called)

    def test_cacheRemovesOldest(self):
        """
        Verify that when the cache is full, and a new entry is added,
        the oldest entry is removed.
        """
        trustRoot = CustomOpenSSLTrustRoot()
        wrappedPolicy = BrowserLikePolicyForHTTPS(trustRoot=trustRoot)
        policy = HostnameCachingHTTPSPolicy(wrappedPolicy)
        for i in range(20):
            hostname = "host" + str(i)
            policy.creatorForNetloc(hostname.encode("ascii"), 8675)

        # Force host0, which was the first, to be the most recently used
        host0 = "host0"
        policy.creatorForNetloc(host0.encode("ascii"), 309)
        self.assertIn(host0, policy._cache)
        self.assertEqual(20, len(policy._cache))

        # host1 is now the least recently used entry, so it gets evicted.
        hostn = "new"
        policy.creatorForNetloc(hostn.encode("ascii"), 309)

        host1 = "host1"
        self.assertNotIn(host1, policy._cache)
        self.assertEqual(20, len(policy._cache))

        self.assertIn(hostn, policy._cache)
        self.assertIn(host0, policy._cache)

        # Accessing an item repeatedly does not corrupt the LRU.
        for _ in range(20):
            policy.creatorForNetloc(host0.encode("ascii"), 8675)

        hostNPlus1 = "new1"

        policy.creatorForNetloc(hostNPlus1.encode("ascii"), 800)

        self.assertNotIn("host2", policy._cache)
        self.assertEqual(20, len(policy._cache))

        self.assertIn(hostNPlus1, policy._cache)
        self.assertIn(hostn, policy._cache)
        self.assertIn(host0, policy._cache)

    def test_changeCacheSize(self):
        """
        Verify that changing the cache size results in a policy that
        respects the new cache size and not the default.
        """
        trustRoot = CustomOpenSSLTrustRoot()
        wrappedPolicy = BrowserLikePolicyForHTTPS(trustRoot=trustRoot)
        policy = HostnameCachingHTTPSPolicy(wrappedPolicy, cacheSize=5)
        for i in range(5):
            hostname = "host" + str(i)
            policy.creatorForNetloc(hostname.encode("ascii"), 8675)

        first = "host0"
        self.assertIn(first, policy._cache)
        self.assertEqual(5, len(policy._cache))

        hostn = "new"
        policy.creatorForNetloc(hostn.encode("ascii"), 309)
        self.assertNotIn(first, policy._cache)
        self.assertEqual(5, len(policy._cache))

        self.assertIn(hostn, policy._cache)


class RequestMethodInjectionTests(
    MethodInjectionTestsMixin,
    SynchronousTestCase,
):
    """
    Check that constructing a L{client.Request} guards against HTTP method
    injection.
    """

    def attemptRequestWithMaliciousMethod(self, method):
        """
        Construct a L{client.Request} whose HTTP method is C{method}.

        @param method: see L{MethodInjectionTestsMixin}
        """
        emptyHeaders = http_headers.Headers()
        client.Request(
            method=method,
            uri=b"http://twisted.invalid",
            headers=emptyHeaders,
            bodyProducer=None,
        )


class RequestWriteToMethodInjectionTests(
    MethodInjectionTestsMixin,
    SynchronousTestCase,
):
    """
    Check that serializing a L{client.Request} with C{writeTo} guards
    against HTTP method injection.
    """

    def attemptRequestWithMaliciousMethod(self, method):
        """
        Build a well-formed I{GET} request, replace its method with C{method},
        and serialize it.

        @param method: see L{MethodInjectionTestsMixin}
        """
        request = client.Request(
            method=b"GET",
            uri=b"http://twisted.invalid",
            headers=http_headers.Headers({b"Host": [b"twisted.invalid"]}),
            bodyProducer=None,
        )
        # Assign after construction so that writeTo itself is what has to
        # cope with the malicious value.
        request.method = method
        sink = StringTransport()
        request.writeTo(sink)


class RequestURIInjectionTests(
    URIInjectionTestsMixin,
    SynchronousTestCase,
):
    """
    Test L{client.Request} against HTTP URI injections.
    """

    def attemptRequestWithMaliciousURI(self, uri):
        """
        Attempt a request with the provided URI.

        @param uri: see L{URIInjectionTestsMixin}
        """
        client.Request(
            method=b"GET",
            uri=uri,
            headers=http_headers.Headers(),
            bodyProducer=None,
        )


class RequestWriteToURIInjectionTests(
    URIInjectionTestsMixin,
    SynchronousTestCase,
):
    """
    Test L{client.Request.writeTo} against HTTP URI injections.
    """

    def attemptRequestWithMaliciousURI(self, uri):
        """
        Attempt to serialize a request whose URI has been replaced with the
        provided URI after construction.

        @param uri: see L{URIInjectionTestsMixin}
        """
        headers = http_headers.Headers({b"Host": [b"twisted.invalid"]})
        req = client.Request(
            method=b"GET",
            uri=b"http://twisted.invalid",
            headers=headers,
            bodyProducer=None,
        )
        # Assign after construction so that writeTo itself has to cope with
        # the malicious value.
        req.uri = uri
        req.writeTo(StringTransport())

Zerion Mini Shell 1.0