# Mini Shell
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.client.Agent} and related new client APIs.
"""
import zlib
from http.cookiejar import CookieJar
from io import BytesIO
from typing import TYPE_CHECKING, List, Optional, Tuple
from unittest import SkipTest, skipIf
from zope.interface.declarations import implementer
from zope.interface.verify import verifyObject
from incremental import Version
from twisted.internet import defer, task
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.defer import CancelledError, Deferred, succeed
from twisted.internet.endpoints import HostnameEndpoint, TCP4ClientEndpoint
from twisted.internet.error import (
ConnectionDone,
ConnectionLost,
ConnectionRefusedError,
)
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
from twisted.internet.protocol import Factory, Protocol
from twisted.internet.task import Clock
from twisted.internet.test.test_endpoints import deterministicResolvingReactor
from twisted.logger import globalLogPublisher
from twisted.python.components import proxyForInterface
from twisted.python.deprecate import getDeprecationWarningString
from twisted.python.failure import Failure
from twisted.test.iosim import FakeTransport, IOPump
from twisted.test.proto_helpers import (
AccumulatingProtocol,
EventLoggingObserver,
MemoryReactorClock,
StringTransport,
)
from twisted.test.test_sslverify import certificatesForAuthorityAndServer
from twisted.trial.unittest import SynchronousTestCase, TestCase
from twisted.web import client, error, http_headers
from twisted.web._newclient import (
HTTP11ClientProtocol,
PotentialDataLoss,
RequestNotSent,
RequestTransmissionFailed,
Response,
ResponseFailed,
ResponseNeverReceived,
)
from twisted.web.client import (
URI,
BrowserLikePolicyForHTTPS,
FileBodyProducer,
HostnameCachingHTTPSPolicy,
HTTPConnectionPool,
Request,
ResponseDone,
_HTTP11ClientFactory,
)
from twisted.web.error import SchemeNotSupported
from twisted.web.http_headers import Headers
from twisted.web.iweb import (
UNKNOWN_LENGTH,
IAgent,
IAgentEndpointFactory,
IBodyProducer,
IPolicyForHTTPS,
IResponse,
)
from twisted.web.test.injectionhelpers import (
MethodInjectionTestsMixin,
URIInjectionTestsMixin,
)
# Creatively lie to mypy about the nature of inheritance, since dealing with
# expectations of a mixin class is basically impossible (don't use mixins).
if TYPE_CHECKING:
    # Seen only by static type checkers: the mixin base is presented as a
    # TestCase and the "runtime" base as a plain object.
    testMixinClass = TestCase
    runtimeTestCase = object
else:
    # Seen by the interpreter: the two aliases are swapped, so the mixin base
    # is a plain object and the runtime base is the real TestCase.
    testMixinClass = object
    runtimeTestCase = TestCase
try:
    from twisted.internet import ssl as _ssl
except ImportError:
    # TLS support could not be imported; expose that through the
    # module-level C{ssl} / C{sslPresent} names.
    ssl = None
    sslPresent = False
else:
    ssl = _ssl
    sslPresent = True
    from twisted.internet._sslverify import ClientTLSOptions, IOpenSSLTrustRoot
    from twisted.internet.ssl import optionsForClientTLS
    from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol

    # Defined inside the C{else} branch because it needs IOpenSSLTrustRoot,
    # which is only imported when the ssl import above succeeded.
    @implementer(IOpenSSLTrustRoot)
    class CustomOpenSSLTrustRoot:
        """
        An L{IOpenSSLTrustRoot} implementation which only records that it was
        asked to add CA certificates, and which context it was given.
        """

        # Whether _addCACertsToContext has been invoked on this instance.
        called = False
        # The context object passed to the most recent _addCACertsToContext
        # call, or None if there has been no call yet.
        context = None

        def _addCACertsToContext(self, context):
            """
            Record the call and remember C{context}; nothing is added to it.

            @param context: The object handed in by the caller; stored
                unchanged on C{self.context}.
            """
            self.called = True
            self.context = context
class StubHTTPProtocol(Protocol):
    """
    A stand-in for L{HTTP11ClientProtocol} which speaks no HTTP/1.1 at all: it
    merely remembers every request handed to it.

    @ivar requests: A C{list} of two-tuples.  One tuple is appended per call
        to L{request}, holding the request object and the L{Deferred} which
        was returned for it.

    @ivar state: A C{str} naming the protocol state; starts out as
        C{"QUIESCENT"}.
    """

    def __init__(self) -> None:
        self.state = "QUIESCENT"
        self.requests: List[Tuple[Request, Deferred[IResponse]]] = []

    def request(self, request):
        """
        Remember C{request} so that a test can examine it afterwards.

        @param request: The request object to record.

        @return: A L{Deferred} which this code will never fire.
        """
        pending = Deferred()
        entry = (request, pending)
        self.requests.append(entry)
        return pending
class FileConsumer:
    """
    A minimal consumer which forwards everything written to it on to a
    file-like object.

    @ivar outputFile: The file-like object which receives the written data.
    """

    def __init__(self, outputFile):
        """
        @param outputFile: An object with a C{write} method.
        """
        self.outputFile = outputFile

    def write(self, bytes):
        """
        Pass C{bytes} straight through to C{self.outputFile}.

        @param bytes: The data to write.
        """
        sink = self.outputFile
        sink.write(bytes)
class FileBodyProducerTests(TestCase):
    """
    Tests for the L{FileBodyProducer} which reads bytes from a file and writes
    them to an L{IConsumer}.
    """

    def _termination(self):
        """
        This method can be used as the C{terminationPredicateFactory} for a
        L{Cooperator}.  It returns a predicate which immediately returns
        C{True}, indicating that no more work should be done this iteration.
        This has the result of only allowing one iteration of a cooperative
        task to be run per L{Cooperator} iteration.
        """
        return lambda: True

    def setUp(self):
        """
        Create a L{Cooperator} hooked up to an easily controlled, deterministic
        scheduler to use with L{FileBodyProducer}.
        """
        # Calls the Cooperator wants scheduled are collected here; tests run
        # them one at a time with C{self._scheduled.pop(0)()}.
        self._scheduled = []
        self.cooperator = task.Cooperator(self._termination, self._scheduled.append)

    def test_interface(self):
        """
        L{FileBodyProducer} instances provide L{IBodyProducer}.
        """
        self.assertTrue(verifyObject(IBodyProducer, FileBodyProducer(BytesIO(b""))))

    def test_unknownLength(self):
        """
        If the L{FileBodyProducer} is constructed with a file-like object
        lacking either a C{seek} or a C{tell} method, its C{length} attribute
        is set to C{UNKNOWN_LENGTH}.
        """

        class HasSeek:
            def seek(self, offset, whence):
                pass

        class HasTell:
            def tell(self):
                pass

        producer = FileBodyProducer(HasSeek())
        self.assertEqual(UNKNOWN_LENGTH, producer.length)
        producer = FileBodyProducer(HasTell())
        self.assertEqual(UNKNOWN_LENGTH, producer.length)

    def test_knownLength(self):
        """
        If the L{FileBodyProducer} is constructed with a file-like object with
        both C{seek} and C{tell} methods, its C{length} attribute is set to the
        size of the file as determined by those methods.
        """
        inputBytes = b"here are some bytes"
        inputFile = BytesIO(inputBytes)
        inputFile.seek(5)
        producer = FileBodyProducer(inputFile)
        self.assertEqual(len(inputBytes) - 5, producer.length)
        # Measuring the length must leave the file position where it was.
        self.assertEqual(inputFile.tell(), 5)

    def test_defaultCooperator(self):
        """
        If no L{Cooperator} instance is passed to L{FileBodyProducer}, the
        global cooperator is used.
        """
        producer = FileBodyProducer(BytesIO(b""))
        self.assertEqual(task.cooperate, producer._cooperate)

    def test_startProducing(self):
        """
        L{FileBodyProducer.startProducing} starts writing bytes from the input
        file to the given L{IConsumer} and returns a L{Deferred} which fires
        when they have all been written.
        """
        expectedResult = b"hello, world"
        readSize = 3
        output = BytesIO()
        consumer = FileConsumer(output)
        producer = FileBodyProducer(BytesIO(expectedResult), self.cooperator, readSize)
        complete = producer.startProducing(consumer)
        # One scheduled call per chunk, plus one more for the read which
        # discovers end-of-file.
        for _ in range(len(expectedResult) // readSize + 1):
            self._scheduled.pop(0)()
        self.assertEqual([], self._scheduled)
        self.assertEqual(expectedResult, output.getvalue())
        self.assertIsNone(self.successResultOf(complete))

    def test_inputClosedAtEOF(self):
        """
        When L{FileBodyProducer} reaches end-of-file on the input file given to
        it, the input file is closed.
        """
        readSize = 4
        inputBytes = b"some friendly bytes"
        inputFile = BytesIO(inputBytes)
        producer = FileBodyProducer(inputFile, self.cooperator, readSize)
        consumer = FileConsumer(BytesIO())
        producer.startProducing(consumer)
        for _ in range(len(inputBytes) // readSize + 2):
            self._scheduled.pop(0)()
        self.assertTrue(inputFile.closed)

    def test_failedReadWhileProducing(self):
        """
        If a read from the input file fails while producing bytes to the
        consumer, the L{Deferred} returned by
        L{FileBodyProducer.startProducing} fires with a L{Failure} wrapping
        that exception.
        """

        class BrokenFile:
            def read(self, count):
                raise OSError("Simulated bad thing")

        producer = FileBodyProducer(BrokenFile(), self.cooperator)
        complete = producer.startProducing(FileConsumer(BytesIO()))
        self._scheduled.pop(0)()
        # Trap the same exception type which BrokenFile.read raises.
        self.failureResultOf(complete).trap(OSError)

    def test_cancelWhileProducing(self):
        """
        When the L{Deferred} returned by L{FileBodyProducer.startProducing} is
        cancelled, the input file is closed and the task is stopped.
        """
        expectedResult = b"hello, world"
        readSize = 3
        output = BytesIO()
        consumer = FileConsumer(output)
        inputFile = BytesIO(expectedResult)
        producer = FileBodyProducer(inputFile, self.cooperator, readSize)
        complete = producer.startProducing(consumer)
        complete.cancel()
        self.assertTrue(inputFile.closed)
        self._scheduled.pop(0)()
        self.assertEqual(b"", output.getvalue())
        self.assertNoResult(complete)

    def test_stopProducing(self):
        """
        L{FileBodyProducer.stopProducing} stops the underlying L{IPullProducer}
        and the cooperative task responsible for calling C{resumeProducing} and
        closes the input file but does not cause the L{Deferred} returned by
        C{startProducing} to fire.
        """
        expectedResult = b"hello, world"
        readSize = 3
        output = BytesIO()
        consumer = FileConsumer(output)
        inputFile = BytesIO(expectedResult)
        producer = FileBodyProducer(inputFile, self.cooperator, readSize)
        complete = producer.startProducing(consumer)
        producer.stopProducing()
        self.assertTrue(inputFile.closed)
        self._scheduled.pop(0)()
        self.assertEqual(b"", output.getvalue())
        self.assertNoResult(complete)

    def test_pauseProducing(self):
        """
        L{FileBodyProducer.pauseProducing} temporarily suspends writing bytes
        from the input file to the given L{IConsumer}.
        """
        expectedResult = b"hello, world"
        readSize = 5
        output = BytesIO()
        consumer = FileConsumer(output)
        producer = FileBodyProducer(BytesIO(expectedResult), self.cooperator, readSize)
        complete = producer.startProducing(consumer)
        self._scheduled.pop(0)()
        self.assertEqual(output.getvalue(), expectedResult[:5])
        producer.pauseProducing()

        # Sort of depends on an implementation detail of Cooperator: even
        # though the only task is paused, there's still a scheduled call.  If
        # this were to go away because Cooperator became smart enough to cancel
        # this call in this case, that would be fine.
        self._scheduled.pop(0)()

        # Since the producer is paused, no new data should be here.
        self.assertEqual(output.getvalue(), expectedResult[:5])
        self.assertEqual([], self._scheduled)
        self.assertNoResult(complete)

    def test_resumeProducing(self):
        """
        L{FileBodyProducer.resumeProducing} re-commences writing bytes from the
        input file to the given L{IConsumer} after it was previously paused
        with L{FileBodyProducer.pauseProducing}.
        """
        expectedResult = b"hello, world"
        readSize = 5
        output = BytesIO()
        consumer = FileConsumer(output)
        producer = FileBodyProducer(BytesIO(expectedResult), self.cooperator, readSize)
        producer.startProducing(consumer)
        self._scheduled.pop(0)()
        self.assertEqual(expectedResult[:readSize], output.getvalue())
        producer.pauseProducing()
        producer.resumeProducing()
        self._scheduled.pop(0)()
        self.assertEqual(expectedResult[: readSize * 2], output.getvalue())

    def test_multipleStop(self):
        """
        L{FileBodyProducer.stopProducing} can be called more than once without
        raising an exception.
        """
        expectedResult = b"test"
        readSize = 3
        output = BytesIO()
        consumer = FileConsumer(output)
        inputFile = BytesIO(expectedResult)
        producer = FileBodyProducer(inputFile, self.cooperator, readSize)
        complete = producer.startProducing(consumer)
        producer.stopProducing()
        producer.stopProducing()
        self.assertTrue(inputFile.closed)
        self._scheduled.pop(0)()
        self.assertEqual(b"", output.getvalue())
        self.assertNoResult(complete)
# Fixed addresses which the host map built in
# FakeReactorAndConnectMixin.createReactor returns for the hostnames used by
# these tests (example.com, ipv6.example.com, example.net, example.org, foo
# and foo.com respectively).
EXAMPLE_COM_IP = "127.0.0.7"
EXAMPLE_COM_V6_IP = "::7"
EXAMPLE_NET_IP = "127.0.0.8"
EXAMPLE_ORG_IP = "127.0.0.9"
FOO_LOCAL_IP = "127.0.0.10"
FOO_COM_IP = "127.0.0.11"
class FakeReactorAndConnectMixin:
    """
    A test mixin supplying a testable C{Reactor} and a dummy C{connect}
    method, so that an instance of the test case can itself stand in for an
    endpoint.
    """

    def createReactor(self):
        """
        Build a L{MemoryReactorClock} which is able to resolve a fixed set of
        hostnames.

        @return: a L{MemoryReactorClock}-like object with a slightly limited
            interface (only C{advance} and C{tcpClients} in addition to its
            formally-declared reactor interfaces), which can resolve a fixed
            set of domains.
        """
        memoryReactor = MemoryReactorClock()
        knownHosts = {
            "example.com": [EXAMPLE_COM_IP],
            "ipv6.example.com": [EXAMPLE_COM_V6_IP],
            "example.net": [EXAMPLE_NET_IP],
            "example.org": [EXAMPLE_ORG_IP],
            "foo": [FOO_LOCAL_IP],
            "foo.com": [FOO_COM_IP],
            "127.0.0.7": ["127.0.0.7"],
            "::7": ["::7"],
        }
        resolvingReactor = deterministicResolvingReactor(
            memoryReactor, hostMap=knownHosts
        )

        # Many tests assume that the MemoryReactorClock and the reactor handed
        # to the system under test are one and the same object, so expose the
        # pieces of the former which they rely upon on the latter.
        resolvingReactor.tcpClients = memoryReactor.tcpClients
        resolvingReactor.advance = memoryReactor.advance
        return resolvingReactor

    class StubEndpoint:
        """
        An endpoint wrapping another endpoint.  It substitutes a
        L{StubHTTPProtocol} for the real protocol and attaches that protocol
        instance to the given test case.
        """

        def __init__(self, endpoint, testCase):
            """
            @param endpoint: The real endpoint being wrapped.
            @param testCase: The test case which will receive the stub
                protocol as its C{protocol} attribute upon L{connect}.
            """
            self.endpoint = endpoint
            self.testCase = testCase

            def nothing():
                """this function does nothing"""

            def buildStubProtocol(addr):
                # Always hand back the one stub protocol, whatever the address.
                return self.protocol

            self.factory = _HTTP11ClientFactory(nothing, repr(self.endpoint))
            self.protocol = StubHTTPProtocol()
            self.factory.buildProtocol = buildStubProtocol

        def connect(self, ignoredFactory):
            """
            Connect the wrapped endpoint using this stub's own factory rather
            than C{ignoredFactory}.

            @return: An already-fired L{Deferred} carrying the stub protocol.
            """
            self.testCase.protocol = self.protocol
            self.endpoint.connect(self.factory)
            return succeed(self.protocol)

    def buildAgentForWrapperTest(self, reactor):
        """
        Create an Agent for tests which wrap the Agent and need both a fake
        reactor and a L{StubHTTPProtocol}.

        @param reactor: The reactor to construct the agent with.

        @return: A L{client.Agent} whose C{_getEndpoint} results are wrapped
            in L{StubEndpoint}.
        """
        agent = client.Agent(reactor)
        realGetEndpoint = agent._getEndpoint

        def stubbedGetEndpoint(*args):
            return self.StubEndpoint(realGetEndpoint(*args), self)

        agent._getEndpoint = stubbedGetEndpoint
        return agent

    def connect(self, factory):
        """
        A fake endpoint C{connect} which synchronously succeeds with a fresh
        L{StubHTTPProtocol}, remembered as C{self.protocol}, for ease of
        testing.
        """
        stub = StubHTTPProtocol()
        stub.makeConnection(None)
        self.protocol = stub
        return succeed(stub)
class DummyEndpoint:
    """
    An endpoint which connects protocols to a fake transport.
    """

    def connect(self, factory):
        """
        Build a protocol from C{factory} and connect it to a new
        L{StringTransport}.

        @return: An already-fired L{Deferred} carrying the protocol.
        """
        builtProtocol = factory.buildProtocol(None)
        fakeTransport = StringTransport()
        builtProtocol.makeConnection(fakeTransport)
        return succeed(builtProtocol)
class BadEndpoint:
    """
    An endpoint whose C{connect} must never be invoked.
    """

    def connect(self, factory):
        """
        Fail loudly: reaching this method means the endpoint was used.

        @raise RuntimeError: Always.
        """
        message = "This endpoint should not have been used."
        raise RuntimeError(message)
class DummyFactory(Factory):
    """
    A factory which builds C{StubHTTPProtocol} instances.
    """

    protocol = StubHTTPProtocol

    def __init__(self, quiescentCallback, metadata):
        """
        Accept and discard both arguments.

        @param quiescentCallback: Ignored.
        @param metadata: Ignored.
        """
class HTTPConnectionPoolTests(TestCase, FakeReactorAndConnectMixin):
    """
    Tests for the L{HTTPConnectionPool} class.
    """

    def setUp(self):
        """
        Create an L{HTTPConnectionPool} on a fake reactor, building
        L{StubHTTPProtocol} instances and with automatic retries disabled.
        """
        self.fakeReactor = self.createReactor()
        self.pool = HTTPConnectionPool(self.fakeReactor)
        self.pool._factory = DummyFactory
        # The retry code path is tested in HTTPConnectionPoolRetryTests:
        self.pool.retryAutomatically = False

    def _assertNotCached(self, connection):
        """
        Assert that C{connection} is in none of the pool's per-key lists of
        cached connections.

        Checking C{connection} against C{self.pool._connections.values()}
        directly would compare it with the lists themselves rather than with
        their contents, and so could never fail.
        """
        for cached in self.pool._connections.values():
            self.assertNotIn(connection, cached)

    def test_getReturnsNewIfCacheEmpty(self):
        """
        If there are no cached connections,
        L{HTTPConnectionPool.getConnection} returns a new connection.
        """
        self.assertEqual(self.pool._connections, {})

        def gotConnection(conn):
            self.assertIsInstance(conn, StubHTTPProtocol)
            # The new connection is not stored in the pool:
            self._assertNotCached(conn)

        unknownKey = 12245
        d = self.pool.getConnection(unknownKey, DummyEndpoint())
        return d.addCallback(gotConnection)

    def test_putStartsTimeout(self):
        """
        If a connection is put back to the pool, a 240-sec timeout is started.

        When the timeout hits, the connection is closed and removed from the
        pool.
        """
        key = ("http", b"example.com", 80)

        # We start out with one cached connection:
        protocol = StubHTTPProtocol()
        protocol.makeConnection(StringTransport())
        self.pool._putConnection(key, protocol)

        # Connection is in pool, still not closed:
        self.assertEqual(protocol.transport.disconnecting, False)
        self.assertIn(protocol, self.pool._connections[key])

        # Advance 239 seconds, still not closed:
        self.fakeReactor.advance(239)
        self.assertEqual(protocol.transport.disconnecting, False)
        self.assertIn(protocol, self.pool._connections[key])
        self.assertIn(protocol, self.pool._timeouts)

        # Advance past 240 seconds, connection will be closed:
        self.fakeReactor.advance(1.1)
        self.assertEqual(protocol.transport.disconnecting, True)
        self.assertNotIn(protocol, self.pool._connections[key])
        self.assertNotIn(protocol, self.pool._timeouts)

    def test_putExceedsMaxPersistent(self):
        """
        If an idle connection is put back in the cache and the max number of
        persistent connections has been exceeded, one of the connections is
        closed and removed from the cache.
        """
        pool = self.pool
        key = ("http", b"example.com", 80)

        # We start out with two cached connections, the max:
        origCached = [StubHTTPProtocol(), StubHTTPProtocol()]
        for p in origCached:
            p.makeConnection(StringTransport())
            pool._putConnection(key, p)
        self.assertEqual(pool._connections[key], origCached)
        timeouts = pool._timeouts.copy()

        # Now we add another one:
        newProtocol = StubHTTPProtocol()
        newProtocol.makeConnection(StringTransport())
        pool._putConnection(key, newProtocol)

        # The oldest cached connection will be removed and disconnected:
        newCached = pool._connections[key]
        self.assertEqual(len(newCached), 2)
        self.assertEqual(newCached, [origCached[1], newProtocol])
        self.assertEqual([p.transport.disconnecting for p in newCached], [False, False])
        self.assertEqual(origCached[0].transport.disconnecting, True)
        self.assertTrue(timeouts[origCached[0]].cancelled)
        self.assertNotIn(origCached[0], pool._timeouts)

    def test_maxPersistentPerHost(self):
        """
        C{maxPersistentPerHost} is enforced per C{(scheme, host, port)}:
        different keys have different max connections.
        """

        def addProtocol(scheme, host, port):
            p = StubHTTPProtocol()
            p.makeConnection(StringTransport())
            self.pool._putConnection((scheme, host, port), p)
            return p

        persistent = [
            addProtocol("http", b"example.com", 80),
            addProtocol("http", b"example.com", 80),
        ]
        addProtocol("https", b"example.com", 443)
        addProtocol("http", b"www2.example.com", 80)

        self.assertEqual(
            self.pool._connections[("http", b"example.com", 80)], persistent
        )
        self.assertEqual(len(self.pool._connections[("https", b"example.com", 443)]), 1)
        self.assertEqual(
            len(self.pool._connections[("http", b"www2.example.com", 80)]), 1
        )

    def test_getCachedConnection(self):
        """
        Getting an address which has a cached connection returns the cached
        connection, removes it from the cache and cancels its timeout.
        """
        key = ("http", b"example.com", 80)

        # We start out with one cached connection:
        protocol = StubHTTPProtocol()
        protocol.makeConnection(StringTransport())
        self.pool._putConnection(key, protocol)

        def gotConnection(conn):
            # We got the cached connection:
            self.assertIdentical(protocol, conn)
            self.assertNotIn(conn, self.pool._connections[key])
            # And the timeout was cancelled:
            self.fakeReactor.advance(241)
            self.assertEqual(conn.transport.disconnecting, False)
            self.assertNotIn(conn, self.pool._timeouts)

        return self.pool.getConnection(
            key,
            BadEndpoint(),
        ).addCallback(gotConnection)

    def test_newConnection(self):
        """
        The pool's C{_newConnection} method constructs a new connection.
        """
        # We start out with one cached connection:
        protocol = StubHTTPProtocol()
        protocol.makeConnection(StringTransport())
        key = 12245
        self.pool._putConnection(key, protocol)

        def gotConnection(newConnection):
            # We got a new connection:
            self.assertNotIdentical(protocol, newConnection)
            # And the old connection is still there:
            self.assertIn(protocol, self.pool._connections[key])
            # While the new connection is not:
            self._assertNotCached(newConnection)

        d = self.pool._newConnection(key, DummyEndpoint())
        return d.addCallback(gotConnection)

    def test_getSkipsDisconnected(self):
        """
        When getting connections out of the cache, disconnected connections
        are removed and not returned.
        """
        pool = self.pool
        key = ("http", b"example.com", 80)

        # We start out with two cached connections, the max:
        origCached = [StubHTTPProtocol(), StubHTTPProtocol()]
        for p in origCached:
            p.makeConnection(StringTransport())
            pool._putConnection(key, p)
        self.assertEqual(pool._connections[key], origCached)

        # We close the first one:
        origCached[0].state = "DISCONNECTED"

        # Now, when we retrieve connections we should get the *second* one:
        result = []
        pool.getConnection(key, BadEndpoint()).addCallback(result.append)
        self.assertIdentical(result[0], origCached[1])

        # And both the disconnected and removed connections should be out of
        # the cache:
        self.assertEqual(pool._connections[key], [])
        self.assertEqual(pool._timeouts, {})

    def test_putNotQuiescent(self):
        """
        If a non-quiescent connection is put back in the cache, an error is
        logged.
        """
        key = ("http", b"example.com", 80)
        protocol = StubHTTPProtocol()
        # By default state is QUIESCENT
        self.assertEqual(protocol.state, "QUIESCENT")

        logObserver = EventLoggingObserver.createWithCleanup(self, globalLogPublisher)

        protocol.state = "NOTQUIESCENT"
        self.pool._putConnection(key, protocol)
        self.assertEqual(1, len(logObserver))

        event = logObserver[0]
        f = event["log_failure"]

        self.assertIsInstance(f.value, RuntimeError)
        self.assertEqual(
            f.getErrorMessage(), "BUG: Non-quiescent protocol added to connection pool."
        )
        # The offending protocol was not cached:
        self.assertIdentical(None, self.pool._connections.get(key))
        self.flushLoggedErrors(RuntimeError)

    def test_getUsesQuiescentCallback(self):
        """
        When L{HTTPConnectionPool.getConnection} connects, it returns a
        C{Deferred} that fires with an instance of L{HTTP11ClientProtocol}
        that has the correct quiescent callback attached.  When this callback
        is called the protocol is returned to the cache correctly, using the
        right key.
        """

        class StringEndpoint:
            def connect(self, factory):
                p = factory.buildProtocol(None)
                p.makeConnection(StringTransport())
                return succeed(p)

        pool = HTTPConnectionPool(self.fakeReactor, True)
        pool.retryAutomatically = False
        result = []
        key = "a key"
        pool.getConnection(key, StringEndpoint()).addCallback(result.append)
        protocol = result[0]
        self.assertIsInstance(protocol, HTTP11ClientProtocol)

        # Now that we have a protocol instance, let's try to put it back in
        # the pool:
        protocol._state = "QUIESCENT"
        protocol._quiescentCallback(protocol)

        # If we try to retrieve a connection to the same destination again, we
        # should get the same protocol, because it should've been added back
        # to the pool:
        result2 = []
        pool.getConnection(key, StringEndpoint()).addCallback(result2.append)
        self.assertIdentical(result2[0], protocol)

    def test_closeCachedConnections(self):
        """
        L{HTTPConnectionPool.closeCachedConnections} closes all cached
        connections and removes them from the cache.  It returns a Deferred
        that fires when they have all lost their connections.
        """
        persistent = []

        def addProtocol(scheme, host, port):
            p = HTTP11ClientProtocol()
            p.makeConnection(StringTransport())
            self.pool._putConnection((scheme, host, port), p)
            persistent.append(p)

        addProtocol("http", b"example.com", 80)
        addProtocol("http", b"www2.example.com", 80)
        doneDeferred = self.pool.closeCachedConnections()

        # Connections have begun disconnecting:
        for p in persistent:
            self.assertEqual(p.transport.disconnecting, True)
        self.assertEqual(self.pool._connections, {})
        # All timeouts were cancelled and removed:
        for dc in self.fakeReactor.getDelayedCalls():
            self.assertEqual(dc.cancelled, True)
        self.assertEqual(self.pool._timeouts, {})

        # Returned Deferred fires when all connections have been closed:
        result = []
        doneDeferred.addCallback(result.append)
        self.assertEqual(result, [])
        persistent[0].connectionLost(Failure(ConnectionDone()))
        self.assertEqual(result, [])
        persistent[1].connectionLost(Failure(ConnectionDone()))
        self.assertEqual(result, [None])

    def test_cancelGetConnectionCancelsEndpointConnect(self):
        """
        Cancelling the C{Deferred} returned from
        L{HTTPConnectionPool.getConnection} cancels the C{Deferred} returned
        by opening a new connection with the given endpoint.
        """
        self.assertEqual(self.pool._connections, {})
        connectionResult = Deferred()

        class Endpoint:
            def connect(self, factory):
                return connectionResult

        d = self.pool.getConnection(12345, Endpoint())
        d.cancel()
        self.assertEqual(self.failureResultOf(connectionResult).type, CancelledError)
class AgentTestsMixin:
    """
    Tests which apply to every L{IAgent} implementation.
    """

    def test_interface(self):
        """
        The object returned by C{self.makeAgent()} provides L{IAgent}.
        """
        agent = self.makeAgent()
        providesInterface = verifyObject(IAgent, agent)
        self.assertTrue(providesInterface)
class IntegrationTestingMixin:
    """
    Transport-to-Agent integration tests for both HTTP and HTTPS.
    """

    def test_integrationTestIPv4(self):
        """
        L{Agent} works over IPv4.
        """
        self.integrationTest(b"example.com", EXAMPLE_COM_IP, IPv4Address)

    def test_integrationTestIPv4Address(self):
        """
        L{Agent} works over IPv4 when hostname is an IPv4 address.
        """
        self.integrationTest(b"127.0.0.7", "127.0.0.7", IPv4Address)

    def test_integrationTestIPv6(self):
        """
        L{Agent} works over IPv6.
        """
        self.integrationTest(b"ipv6.example.com", EXAMPLE_COM_V6_IP, IPv6Address)

    def test_integrationTestIPv6Address(self):
        """
        L{Agent} works over IPv6 when hostname is an IPv6 address.
        """
        self.integrationTest(b"[::7]", "::7", IPv6Address)

    def integrationTest(
        self,
        hostName,
        expectedAddress,
        addressType,
        serverWrapper=lambda server: server,
        createAgent=client.Agent,
        scheme=b"http",
    ):
        """
        L{Agent} will make a TCP connection, send an HTTP request, and return a
        L{Deferred} that fires when the response has been received.

        @param hostName: The hostname to interpolate into the URL to be
            requested.
        @type hostName: L{bytes}

        @param expectedAddress: The expected address string.
        @type expectedAddress: L{str}

        @param addressType: The class to construct an address out of.
        @type addressType: L{type}

        @param serverWrapper: A callable that takes a protocol factory and
            returns a protocol factory; used to wrap the server / responder
            side in a TLS server.
        @type serverWrapper:
            serverWrapper(L{twisted.internet.interfaces.IProtocolFactory}) ->
            L{twisted.internet.interfaces.IProtocolFactory}

        @param createAgent: A callable that takes a reactor and produces an
            L{IAgent}; used to construct an agent with an appropriate trust
            root for TLS.
        @type createAgent: createAgent(reactor) -> L{IAgent}

        @param scheme: The scheme to test, C{http} or C{https}
        @type scheme: L{bytes}
        """
        reactor = self.createReactor()
        agent = createAgent(reactor)
        deferred = agent.request(b"GET", scheme + b"://" + hostName + b"/")

        # The agent should have attempted exactly the connection we expect;
        # only the host, port and factory of that attempt are needed below.
        host, port, factory, _timeout, _bind = reactor.tcpClients[0]
        self.assertEqual(host, expectedAddress)
        peerAddress = addressType("TCP", host, port)
        clientProtocol = factory.buildProtocol(peerAddress)
        clientTransport = FakeTransport(clientProtocol, False, peerAddress=peerAddress)
        clientProtocol.makeConnection(clientTransport)

        # After decoration, ``accumulator`` is the factory; the function body
        # records the protocol it builds on that factory so that the test can
        # reach it afterwards.
        @Factory.forProtocol
        def accumulator():
            ap = AccumulatingProtocol()
            accumulator.currentProtocol = ap
            return ap

        accumulator.currentProtocol = None
        accumulator.protocolConnectionMade = None
        wrapper = serverWrapper(accumulator).buildProtocol(None)
        serverTransport = FakeTransport(wrapper, True)
        wrapper.makeConnection(serverTransport)
        pump = IOPump(clientProtocol, wrapper, clientTransport, serverTransport, False)
        pump.flush()
        self.assertNoResult(deferred)

        # Inspect the request as the server side received it.
        lines = accumulator.currentProtocol.data.split(b"\r\n")
        self.assertTrue(lines[0].startswith(b"GET / HTTP"), lines[0])
        headers = dict(line.split(b": ", 1) for line in lines[1:] if line)
        self.assertEqual(headers[b"Host"], hostName)
        self.assertNoResult(deferred)

        # NOTE(review): the blank line following X-An-Header ends the header
        # section, so the Content-length line below is presumably received as
        # part of the body rather than as a header -- confirm this is intended.
        accumulator.currentProtocol.transport.write(
            b"HTTP/1.1 200 OK"
            b"\r\nX-An-Header: an-value\r\n"
            b"\r\nContent-length: 12\r\n\r\n"
            b"hello world!"
        )
        pump.flush()
        response = self.successResultOf(deferred)
        self.assertEqual(
            response.headers.getRawHeaders(b"x-an-header")[0], b"an-value"
        )
@implementer(IAgentEndpointFactory)
class StubEndpointFactory:
    """
    A stand-in L{IAgentEndpointFactory} used by the tests.
    """

    def endpointForURI(self, uri):
        """
        Testing implementation which describes the URI instead of building an
        endpoint for it.

        @param uri: A L{URI}.

        @return: C{(scheme, host, port)} of passed in URI; violation of
            interface but useful for testing.
        @rtype: L{tuple}
        """
        scheme, host, port = uri.scheme, uri.host, uri.port
        return (scheme, host, port)
class AgentTests(
TestCase, FakeReactorAndConnectMixin, AgentTestsMixin, IntegrationTestingMixin
):
"""
Tests for the new HTTP client API provided by L{Agent}.
"""
def makeAgent(self):
    """
    Build the agent under test on top of C{self.reactor}.

    @return: a new L{twisted.web.client.Agent} instance
    """
    agent = client.Agent(self.reactor)
    return agent
def setUp(self):
    """
    Build a fake reactor and an L{Agent} wrapped around it, stored as
    C{self.reactor} and C{self.agent}.
    """
    fakeReactor = self.createReactor()
    # makeAgent reads self.reactor, so the reactor has to be stored first.
    self.reactor = fakeReactor
    self.agent = self.makeAgent()
def test_defaultPool(self):
    """
    When constructed without a pool, the L{Agent} makes its own
    non-persistent L{HTTPConnectionPool} sharing the agent's reactor.
    """
    agent = client.Agent(self.reactor)
    defaultPool = agent._pool
    self.assertIsInstance(defaultPool, HTTPConnectionPool)
    self.assertEqual(defaultPool.persistent, False)
    self.assertIdentical(agent._reactor, defaultPool._reactor)
def test_persistent(self):
    """
    With C{persistent} left at C{True} on the L{HTTPConnectionPool} (the
    default), each C{Request} is created with its C{persistent} flag set to
    C{True}.
    """
    agent = client.Agent(self.reactor, pool=HTTPConnectionPool(self.reactor))
    # Make this test case act as the endpoint, so that the request lands on
    # the stub protocol stored in self.protocol.
    agent._getEndpoint = lambda *args: self
    agent.request(b"GET", b"http://127.0.0.1")
    sentRequest = self.protocol.requests[0][0]
    self.assertEqual(sentRequest.persistent, True)
def test_nonPersistent(self):
    """
    When the L{HTTPConnectionPool} is created with C{persistent} set to
    C{False}, each C{Request} is created with its C{persistent} flag set to
    C{False}.

    Elsewhere in the tests for the underlying HTTP code we ensure that
    this will result in the disconnection of the HTTP protocol once the
    request is done, so that the connection will not be returned to the
    pool.
    """
    nonPersistentPool = HTTPConnectionPool(self.reactor, persistent=False)
    agent = client.Agent(self.reactor, pool=nonPersistentPool)
    # Make this test case act as the endpoint, so that the request lands on
    # the stub protocol stored in self.protocol.
    agent._getEndpoint = lambda *args: self
    agent.request(b"GET", b"http://127.0.0.1")
    sentRequest = self.protocol.requests[0][0]
    self.assertEqual(sentRequest.persistent, False)
def test_connectUsesConnectionPool(self):
    """
    The Agent obtains its connection through the pool's C{getConnection}
    method, passing the endpoint produced by C{self._getEndpoint}.  The key
    used is C{(scheme, host, port)}.
    """
    endpoint = DummyEndpoint()
    expectedKey = (b"http", b"foo", 80)

    # Inside the two helper classes ``this`` is the helper instance, while
    # ``self`` remains the enclosing test case.
    class MyAgent(client.Agent):
        def _getEndpoint(this, uri):
            self.assertEqual((uri.scheme, uri.host, uri.port), expectedKey)
            return endpoint

    class DummyPool:
        connected = False
        persistent = False

        def getConnection(this, key, ep):
            this.connected = True
            self.assertEqual(ep, endpoint)
            # This is the key the default Agent uses, others will have
            # different keys:
            self.assertEqual(key, expectedKey)
            return defer.succeed(StubHTTPProtocol())

    pool = DummyPool()
    agent = MyAgent(self.reactor, pool=pool)
    self.assertIdentical(pool, agent._pool)

    requestHeaders = http_headers.Headers()
    requestHeaders.addRawHeader(b"host", b"foo")
    producer = object()
    agent.request(
        b"GET", b"http://foo/", headers=requestHeaders, bodyProducer=producer
    )
    self.assertEqual(agent._pool.connected, True)
def test_nonBytesMethod(self):
    """
    Passing a C{method} which is not L{bytes} to L{Agent.request} raises
    L{TypeError}.
    """
    with self.assertRaises(TypeError):
        self.agent.request("GET", b"http://foo.example/")
def test_unsupportedScheme(self):
    """
    For a URI whose scheme is not C{'http'}, L{Agent.request} returns a
    L{Deferred} which fails with L{SchemeNotSupported}.
    """
    pendingResponse = self.agent.request(b"GET", b"mailto:alice@example.com")
    return self.assertFailure(pendingResponse, SchemeNotSupported)
def test_connectionFailed(self):
    """
    If the TCP connection attempt fails, the L{Deferred} returned by
    L{Agent.request} fires with a L{Failure}.
    """
    result = self.agent.request(b"GET", b"http://foo/")
    # Cause the connection to be refused
    _host, _port, factory = self.reactor.tcpClients.pop()[:3]
    refusal = Failure(ConnectionRefusedError())
    factory.clientConnectionFailed(None, refusal)
    self.reactor.advance(10)
    # ^ https://twistedmatrix.com/trac/ticket/8202
    self.failureResultOf(result, ConnectionRefusedError)
def test_connectHTTP(self):
    """
    Given a URI with a scheme of C{'http'}, L{Agent._getEndpoint} returns a
    C{HostnameEndpoint} for that URI's host and port.
    """
    expectedHost = b"example.com"
    expectedPort = 1234
    rawURI = b"http://%b:%d" % (expectedHost, expectedPort)
    endpoint = self.agent._getEndpoint(URI.fromBytes(rawURI))
    self.assertEqual(endpoint._hostStr, "example.com")
    self.assertEqual(endpoint._port, expectedPort)
    self.assertIsInstance(endpoint, HostnameEndpoint)
def test_nonDecodableURI(self):
    """
    L{Agent._getEndpoint} raises a L{ValueError} with an explanatory
    message when the host of the URI it is given cannot be decoded as
    ASCII.
    """
    uri = URI.fromBytes(b"http://example.com:80")
    # Replace the host with the UTF-8 encoding of a non-ASCII name.
    uri.host = "\u2603.com".encode()

    with self.assertRaises(ValueError) as caught:
        self.agent._getEndpoint(uri)

    expectedMessage = (
        "The host of the provided URI ({reprout}) contains "
        "non-ASCII octets, it should be ASCII "
        "decodable."
    ).format(reprout=repr(uri.host))
    self.assertEqual(caught.exception.args[0], expectedMessage)
def test_hostProvided(self):
    """
    When no C{headers} argument is given to L{Agent.request}, the issued
    request still carries a I{Host} header derived from the URI.
    """

    def useTestCaseAsEndpoint(*args):
        # The test case itself stands in for the endpoint.
        return self

    self.agent._getEndpoint = useTestCaseAsEndpoint
    self.agent.request(b"GET", b"http://example.com/foo?bar")

    issuedRequest, _ = self.protocol.requests.pop()
    hostValues = issuedRequest.headers.getRawHeaders(b"host")
    self.assertEqual(hostValues, [b"example.com"])
def test_hostIPv6Bracketed(self):
    """
    For a C{uri} containing an IPv6 address literal, the I{Host} header
    computed by L{Agent.request} keeps the surrounding square brackets.
    """

    def useTestCaseAsEndpoint(*args):
        # The test case itself stands in for the endpoint.
        return self

    self.agent._getEndpoint = useTestCaseAsEndpoint
    self.agent.request(b"GET", b"http://[::1]/")

    issuedRequest, _ = self.protocol.requests.pop()
    hostValues = issuedRequest.headers.getRawHeaders(b"host")
    self.assertEqual(hostValues, [b"[::1]"])
def test_hostOverride(self):
    """
    A I{Host} header supplied by the caller in the headers passed to
    L{Agent.request} wins over the value that would otherwise be computed
    from the URI.
    """
    suppliedHeaders = http_headers.Headers(
        {b"foo": [b"bar"], b"host": [b"quux"]}
    )
    self.agent._getEndpoint = lambda *args: self
    self.agent.request(b"GET", b"http://example.com/foo?bar", suppliedHeaders)

    issuedRequest, _ = self.protocol.requests.pop()
    hostValues = issuedRequest.headers.getRawHeaders(b"host")
    self.assertEqual(hostValues, [b"quux"])
def test_headersUnmodified(self):
    """
    L{Agent.request} leaves the caller's L{Headers} instance untouched even
    though a I{Host} header has to be added to the outgoing request.
    """
    suppliedHeaders = http_headers.Headers()
    self.agent._getEndpoint = lambda *args: self
    self.agent.request(b"GET", b"http://example.com/foo", suppliedHeaders)

    # Exactly one request should have reached the protocol.
    issued = self.protocol.requests
    self.assertEqual(len(issued), 1)

    # The caller's object must still compare equal to an empty Headers.
    self.assertEqual(suppliedHeaders, http_headers.Headers())
def test_hostValueStandardHTTP(self):
    """
    For scheme C{'http'} on port C{80}, L{Agent._computeHostValue} returns
    the bare host name with no port suffix.
    """
    computed = self.agent._computeHostValue(b"http", b"example.com", 80)
    self.assertEqual(computed, b"example.com")
def test_hostValueNonStandardHTTP(self):
    """
    For scheme C{'http'} on a port other than C{80},
    L{Agent._computeHostValue} returns the host and the port number joined
    by C{":"}.
    """
    computed = self.agent._computeHostValue(b"http", b"example.com", 54321)
    self.assertEqual(computed, b"example.com:54321")
def test_hostValueStandardHTTPS(self):
    """
    For scheme C{'https'} on port C{443}, L{Agent._computeHostValue}
    returns the bare host name with no port suffix.
    """
    computed = self.agent._computeHostValue(b"https", b"example.com", 443)
    self.assertEqual(computed, b"example.com")
def test_hostValueNonStandardHTTPS(self):
    """
    For scheme C{'https'} on a port other than C{443},
    L{Agent._computeHostValue} returns the host and the port number joined
    by C{":"}.
    """
    computed = self.agent._computeHostValue(b"https", b"example.com", 54321)
    self.assertEqual(computed, b"example.com:54321")
def test_request(self):
    """
    L{Agent.request} issues a L{Request} over the connection it obtains,
    built from the method, the path and query portion of the URI, the given
    headers (plus a computed I{Host} header) and the given body producer.
    """
    self.agent._getEndpoint = lambda *args: self
    suppliedHeaders = http_headers.Headers({b"foo": [b"bar"]})
    # The body producer is only ever compared by identity, so any object
    # will do.
    bodyProducer = object()

    self.agent.request(
        b"GET", b"http://example.com:1234/foo?bar", suppliedHeaders, bodyProducer
    )

    # Exactly one request should have been issued.
    issued = self.protocol.requests
    self.assertEqual(len(issued), 1)

    req, _ = issued.pop()
    self.assertIsInstance(req, Request)
    self.assertEqual(req.method, b"GET")
    self.assertEqual(req.uri, b"/foo?bar")
    expectedHeaders = http_headers.Headers(
        {b"foo": [b"bar"], b"host": [b"example.com:1234"]}
    )
    self.assertEqual(req.headers, expectedHeaders)
    self.assertIs(req.bodyProducer, bodyProducer)
def test_connectTimeout(self):
    """
    The C{connectTimeout} argument given to L{Agent} is forwarded as the
    timeout of the TCP connection attempt recorded by the reactor.
    """
    agent = client.Agent(self.reactor, connectTimeout=5)
    agent.request(b"GET", b"http://foo/")
    # Index 3 of the recorded connection attempt holds the timeout.
    recordedAttempt = self.reactor.tcpClients.pop()
    self.assertEqual(5, recordedAttempt[3])
@skipIf(not sslPresent, "SSL not present, cannot run SSL tests.")
def test_connectTimeoutHTTPS(self):
    """
    For an C{https} URI, the C{connectTimeout} argument given to L{Agent}
    is forwarded as the timeout of the TCP connection attempt recorded by
    the reactor.
    """
    agent = client.Agent(self.reactor, connectTimeout=5)
    agent.request(b"GET", b"https://foo/")
    # Index 3 of the recorded connection attempt holds the timeout.
    recordedAttempt = self.reactor.tcpClients.pop()
    self.assertEqual(5, recordedAttempt[3])
def test_bindAddress(self):
    """
    The C{bindAddress} argument given to L{Agent} is forwarded as the bind
    address of the TCP connection attempt recorded by the reactor.
    """
    agent = client.Agent(self.reactor, bindAddress="192.168.0.1")
    agent.request(b"GET", b"http://foo/")
    # Index 4 of the recorded connection attempt holds the bind address.
    recordedAttempt = self.reactor.tcpClients.pop()
    self.assertEqual("192.168.0.1", recordedAttempt[4])
@skipIf(not sslPresent, "SSL not present, cannot run SSL tests.")
def test_bindAddressSSL(self):
    """
    For an C{https} URI, the C{bindAddress} argument given to L{Agent} is
    forwarded as the bind address of the connection attempt.  The value is
    read from the reactor's C{tcpClients} record.
    """
    agent = client.Agent(self.reactor, bindAddress="192.168.0.1")
    agent.request(b"GET", b"https://foo/")
    # Index 4 of the recorded connection attempt holds the bind address.
    recordedAttempt = self.reactor.tcpClients.pop()
    self.assertEqual("192.168.0.1", recordedAttempt[4])
def test_responseIncludesRequest(self):
    """
    The L{Response} delivered by L{Agent.request} refers back, through its
    C{request} attribute, to the L{Request} that was issued.
    """
    agent = self.buildAgentForWrapperTest(self.reactor)
    d = agent.request(b"GET", b"http://example.com/")

    # Exactly one request should have been issued.
    self.assertEqual(len(self.protocol.requests), 1)
    req, pending = self.protocol.requests.pop()
    self.assertIsInstance(req, Request)

    # Deliver a response constructed around the issued request.
    pending.callback(
        client.Response._construct(
            (b"HTTP", 1, 1), 200, b"OK", client.Headers({}), None, req
        )
    )

    delivered = self.successResultOf(d).request
    seen = (delivered.method, delivered.absoluteURI, delivered.headers)
    wanted = (req.method, req.absoluteURI, req.headers)
    self.assertEqual(seen, wanted)
def test_requestAbsoluteURI(self):
    """
    L{Request.absoluteURI} equals the complete URI that was handed to
    L{Agent.request}, including parameters, query and fragment.
    """
    absolute = b"http://example.com/foo;1234?bar#frag"
    self.buildAgentForWrapperTest(self.reactor).request(b"GET", absolute)

    # Exactly one request should have been issued.
    issued = self.protocol.requests
    self.assertEqual(len(issued), 1)
    req, _ = issued.pop()
    self.assertIsInstance(req, Request)
    self.assertEqual(req.absoluteURI, absolute)
def test_requestMissingAbsoluteURI(self):
    """
    A L{Request} constructed directly, without a parsed URI, reports
    L{None} as its C{absoluteURI}.
    """
    bareRequest = client.Request(b"FOO", b"/", client.Headers(), None)
    self.assertIs(bareRequest.absoluteURI, None)
def test_endpointFactory(self):
    """
    An L{Agent} built with L{Agent.usingEndpointFactory} obtains its
    endpoints from the supplied factory.
    """
    agent = client.Agent.usingEndpointFactory(
        None, endpointFactory=StubEndpointFactory()
    )
    endpoint = agent._getEndpoint(URI.fromBytes(b"http://example.com/"))
    # The stub factory is expected to hand back the (scheme, host, port)
    # triple of the URI.
    self.assertEqual(endpoint, (b"http", b"example.com", 80))
def test_endpointFactoryDefaultPool(self):
    """
    Without an explicit pool, L{Agent.usingEndpointFactory} gives the agent
    a non-persistent L{HTTPConnectionPool} bound to the agent's reactor.
    """
    agent = client.Agent.usingEndpointFactory(self.reactor, StubEndpointFactory())
    pool = agent._pool
    observed = (pool.__class__, pool.persistent, pool._reactor)
    expected = (HTTPConnectionPool, False, agent._reactor)
    self.assertEqual(observed, expected)
def test_endpointFactoryPool(self):
    """
    A pool handed to L{Agent.usingEndpointFactory} becomes the pool of the
    resulting L{Agent}.
    """
    suppliedPool = object()
    factory = StubEndpointFactory()
    agent = client.Agent.usingEndpointFactory(self.reactor, factory, suppliedPool)
    self.assertIs(suppliedPool, agent._pool)
class AgentMethodInjectionTests(
    FakeReactorAndConnectMixin,
    MethodInjectionTestsMixin,
    SynchronousTestCase,
):
    """
    Run the HTTP method injection tests against L{client.Agent}.
    """

    def attemptRequestWithMaliciousMethod(self, method):
        """
        Issue a request through a fresh L{client.Agent}, using C{method} as
        the HTTP method.

        @param method: see L{MethodInjectionTestsMixin}
        """
        reactor = self.createReactor()
        client.Agent(reactor).request(
            method, b"http://twisted.invalid", client.Headers(), None
        )
class AgentURIInjectionTests(
    FakeReactorAndConnectMixin,
    URIInjectionTestsMixin,
    SynchronousTestCase,
):
    """
    Run the URI injection tests against L{client.Agent}.
    """

    def attemptRequestWithMaliciousURI(self, uri):
        """
        Issue a C{GET} request through a fresh L{client.Agent}, using
        C{uri} as the request URI.

        @param uri: see L{URIInjectionTestsMixin}
        """
        reactor = self.createReactor()
        client.Agent(reactor).request(b"GET", uri, client.Headers(), None)
@skipIf(not sslPresent, "SSL not present, cannot run SSL tests.")
class AgentHTTPSTests(TestCase, FakeReactorAndConnectMixin, IntegrationTestingMixin):
    """
    Tests for the parts of the new HTTP client API which depend on SSL.
    """

    def makeEndpoint(self, host=b"example.com", port=443):
        """
        Build an C{https} URI from C{host} and C{port} and return the
        endpoint which a freshly created L{Agent} computes for it.

        @param host: The host for the endpoint.
        @type host: L{bytes}

        @param port: The port for the endpoint.
        @type port: L{int}

        @return: The endpoint returned by L{Agent._getEndpoint} for the
            constructed URI.
        """
        agent = client.Agent(self.createReactor())
        uri = URI.fromBytes(b"https://%b:%d/" % (host, port))
        return agent._getEndpoint(uri)

    def test_endpointType(self):
        """
        For a scheme of C{'https'}, L{Agent._getEndpoint} returns a
        C{_WrapperEndpoint} wrapped around a L{HostnameEndpoint}.
        """
        from twisted.internet.endpoints import _WrapperEndpoint

        endpoint = self.makeEndpoint()
        self.assertIsInstance(endpoint, _WrapperEndpoint)
        self.assertIsInstance(endpoint._wrappedEndpoint, HostnameEndpoint)

    def test_hostArgumentIsRespected(self):
        """
        The host of the URI ends up as the host of the wrapped endpoint.
        """
        wrapped = self.makeEndpoint(host=b"example.com")._wrappedEndpoint
        self.assertEqual(wrapped._hostStr, "example.com")

    def test_portArgumentIsRespected(self):
        """
        The port of the URI ends up as the port of the wrapped endpoint.
        """
        port = 4321
        wrapped = self.makeEndpoint(port=port)._wrappedEndpoint
        self.assertEqual(wrapped._port, port)

    def test_contextFactoryType(self):
        """
        The connection creator used by the endpoint's wrapper factory is a
        C{ClientTLSOptions} configured with the host name of the URI.
        """
        endpoint = self.makeEndpoint()
        wrapperFactory = endpoint._wrapperFactory(None)
        connectionCreator = wrapperFactory._connectionCreator
        self.assertIsInstance(connectionCreator, ClientTLSOptions)
        self.assertEqual(connectionCreator._hostname, "example.com")

    def test_connectHTTPSCustomConnectionCreator(self):
        """
        A custom L{IPolicyForHTTPS} provider passed to L{Agent.__init__} is
        consulted for HTTPS requests: its C{creatorForNetloc} receives the
        host name and port of the URI, and the connection returned by the
        resulting creator is the one on which the TLS handshake is started.
        """
        expectedHost = b"example.org"
        expectedPort = 20443

        class JustEnoughConnection:
            # Flags flipped by the two methods below so the test can see
            # that they were called.
            handshakeStarted = False
            connectState = False

            def do_handshake(self):
                """
                Record that the handshake was started.
                """
                self.handshakeStarted = True

            def set_connect_state(self):
                """
                Record that the connection was put into connect state.
                """
                self.connectState = True

        expectedConnection = JustEnoughConnection()

        # One (tlsProtocol, hostname, port) entry per clientConnectionForTLS
        # call.
        creatorCalls = []

        @implementer(IOpenSSLClientConnectionCreator)
        class JustEnoughCreator:
            def __init__(self, hostname, port):
                self.hostname = hostname
                self.port = port

            def clientConnectionForTLS(self, tlsProtocol):
                """
                Implement L{IOpenSSLClientConnectionCreator}, recording the
                call.

                @param tlsProtocol: The TLS protocol.
                @type tlsProtocol: L{TLSMemoryBIOProtocol}

                @return: C{expectedConnection}
                """
                creatorCalls.append((tlsProtocol, self.hostname, self.port))
                return expectedConnection

        @implementer(IPolicyForHTTPS)
        class StubBrowserLikePolicyForHTTPS:
            def creatorForNetloc(self, hostname, port):
                """
                Emulate L{BrowserLikePolicyForHTTPS}.

                @param hostname: The hostname to verify.
                @type hostname: L{bytes}

                @param port: The port number.
                @type port: L{int}

                @return: a stub L{IOpenSSLClientConnectionCreator}
                @rtype: L{JustEnoughCreator}
                """
                return JustEnoughCreator(hostname, port)

        policy = StubBrowserLikePolicyForHTTPS()
        reactor = self.createReactor()
        agent = client.Agent(reactor, policy)
        uri = URI.fromBytes(b"https://%b:%d" % (expectedHost, expectedPort))
        endpoint = agent._getEndpoint(uri)
        endpoint.connect(Factory.forProtocol(Protocol))

        # Drive the most recently recorded connection attempt by hand.
        tlsFactory = reactor.tcpClients[-1][2]
        tlsProtocol = tlsFactory.buildProtocol(None)
        tlsProtocol.makeConnection(StringTransport())

        firstCall = creatorCalls[0]
        self.assertIsInstance(firstCall[0], TLSMemoryBIOProtocol)
        self.assertEqual(firstCall[1:], (expectedHost, expectedPort))
        self.assertTrue(expectedConnection.handshakeStarted)
        self.assertTrue(expectedConnection.connectState)

    def test_deprecatedDuckPolicy(self):
        """
        Handing L{Agent} an HTTPS policy object which does not provide
        L{IPolicyForHTTPS} emits exactly one L{DeprecationWarning} with the
        expected message.
        """

        def warnMe():
            client.Agent(
                deterministicResolvingReactor(MemoryReactorClock()),
                "does-not-provide-IPolicyForHTTPS",
            )

        warnMe()
        emitted = self.flushWarnings([warnMe])
        self.assertEqual(len(emitted), 1)
        [only] = emitted
        self.assertEqual(only["category"], DeprecationWarning)
        self.assertEqual(
            only["message"],
            "'does-not-provide-IPolicyForHTTPS' was passed as the HTTPS "
            "policy for an Agent, but it does not provide IPolicyForHTTPS. "
            "Since Twisted 14.0, you must pass a provider of IPolicyForHTTPS.",
        )

    def test_alternateTrustRoot(self):
        """
        L{BrowserLikePolicyForHTTPS.creatorForNetloc} consults the trust
        root it was given, and the connection made by the returned creator
        uses the context that trust root was applied to.
        """
        trustRoot = CustomOpenSSLTrustRoot()
        policy = BrowserLikePolicyForHTTPS(trustRoot=trustRoot)
        creator = policy.creatorForNetloc(b"thingy", 4321)
        self.assertTrue(trustRoot.called)
        connection = creator.clientConnectionForTLS(None)
        self.assertIs(trustRoot.context, connection.get_context())

    def integrationTest(self, hostName, expectedAddress, addressType):
        """
        Run the inherited C{integrationTest} over TLS, using a certificate
        issued for C{hostName}.

        The arguments are passed through unchanged to the inherited
        implementation.
        """
        # A bracketed IPv6 literal must lose its brackets before it can be
        # used as a certificate host name.
        certHostName = hostName.strip(b"[]")
        authority, server = certificatesForAuthorityAndServer(
            certHostName.decode("ascii")
        )

        def tlsify(serverFactory):
            # Wrap the server side in TLS using the generated server
            # certificate.
            return TLSMemoryBIOFactory(server.options(), False, serverFactory)

        def tlsagent(reactor):
            from zope.interface import implementer

            from twisted.web.iweb import IPolicyForHTTPS

            @implementer(IPolicyForHTTPS)
            class Policy:
                def creatorForNetloc(self, hostname, port):
                    # Trust only the authority generated above.
                    return optionsForClientTLS(
                        hostname.decode("ascii"), trustRoot=authority
                    )

            return client.Agent(reactor, contextFactory=Policy())

        super().integrationTest(
            hostName,
            expectedAddress,
            addressType,
            serverWrapper=tlsify,
            createAgent=tlsagent,
            scheme=b"https",
        )
class WebClientContextFactoryTests(TestCase):
    """
    Tests for L{twisted.web.client.WebClientContextFactory}, the deprecated
    context factory wrapper for web clients.
    """

    def setUp(self):
        """
        Import C{WebClientContextFactory}, capturing the warnings that the
        import emits so that they do not leak into the individual tests.
        """
        from twisted.web.client import WebClientContextFactory

        self.warned = self.flushWarnings([WebClientContextFactoryTests.setUp])
        self.webClientContextFactory = WebClientContextFactory

    def test_deprecated(self):
        """
        Importing L{twisted.web.client.WebClientContextFactory} emits a
        single L{DeprecationWarning} pointing at
        L{BrowserLikePolicyForHTTPS} as the replacement.
        """
        self.assertEqual(len(self.warned), 1)
        [only] = self.warned
        self.assertEqual(only["category"], DeprecationWarning)
        expectedMessage = getDeprecationWarningString(
            self.webClientContextFactory,
            Version("Twisted", 14, 0, 0),
            replacement=BrowserLikePolicyForHTTPS,
        )
        # See https://twistedmatrix.com/trac/ticket/7242
        expectedMessage = expectedMessage.replace(";", ":")
        self.assertEqual(only["message"], expectedMessage)

    @skipIf(sslPresent, "SSL Present.")
    def test_missingSSL(self):
        """
        Without SSL support, calling C{getContext} raises
        L{NotImplementedError}.
        """
        factory = self.webClientContextFactory()
        self.assertRaises(
            NotImplementedError, factory.getContext, b"example.com", 443
        )

    @skipIf(not sslPresent, "SSL not present, cannot run SSL tests.")
    def test_returnsContext(self):
        """
        With SSL support, C{getContext} returns an L{OpenSSL.SSL.Context}.
        """
        factory = self.webClientContextFactory()
        context = factory.getContext("example.com", 443)
        self.assertIsInstance(context, ssl.SSL.Context)

    @skipIf(not sslPresent, "SSL not present, cannot run SSL tests.")
    def test_setsTrustRootOnContextToDefaultTrustRoot(self):
        """
        The certificate options created by the factory have their
        C{trustRoot} set to the default trust roots.
        """
        factory = self.webClientContextFactory()
        options = factory._getCertificateOptions("example.com", 443)
        self.assertIsInstance(options.trustRoot, ssl.OpenSSLDefaultPaths)
class HTTPConnectionPoolRetryTests(TestCase, FakeReactorAndConnectMixin):
    """
    Tests for the retrying of requests made over previously cached
    connections, which L{client.HTTPConnectionPool} supports by way of
    L{client._RetryingHTTP11ClientProtocol}.
    """

    def test_onlyRetryIdempotentMethods(self):
        """
        Only GET, HEAD, OPTIONS, TRACE, DELETE methods cause a retry.
        """
        pool = client.HTTPConnectionPool(None)
        connection = client._RetryingHTTP11ClientProtocol(None, pool)

        for method in (b"GET", b"HEAD", b"OPTIONS", b"TRACE", b"DELETE"):
            self.assertTrue(connection._shouldRetry(method, RequestNotSent(), None))
        for method in (b"POST", b"MYMETHOD"):
            self.assertFalse(connection._shouldRetry(method, RequestNotSent(), None))
        # This will be covered by a different ticket, since we need support
        # for resettable body producers:
        # self.assertTrue(connection._doRetry("PUT", RequestNotSent(), None))

    def test_onlyRetryIfNoResponseReceived(self):
        """
        A retry happens for L{RequestNotSent}, L{RequestTransmissionFailed}
        and L{ResponseNeverReceived}, but not for L{ResponseFailed} or
        L{ConnectionRefusedError}.
        """
        pool = client.HTTPConnectionPool(None)
        connection = client._RetryingHTTP11ClientProtocol(None, pool)

        def shouldRetry(exception):
            return connection._shouldRetry(b"GET", exception, None)

        self.assertTrue(shouldRetry(RequestNotSent()))
        self.assertTrue(shouldRetry(RequestTransmissionFailed([])))
        self.assertTrue(shouldRetry(ResponseNeverReceived([])))
        self.assertFalse(shouldRetry(ResponseFailed([])))
        self.assertFalse(shouldRetry(ConnectionRefusedError()))

    def test_dontRetryIfFailedDueToCancel(self):
        """
        C{_shouldRetry} returns a false value for a
        L{ResponseNeverReceived} caused by a L{CancelledError}, so a
        cancelled request is not retried.
        """
        pool = client.HTTPConnectionPool(None)
        connection = client._RetryingHTTP11ClientProtocol(None, pool)
        cancelled = ResponseNeverReceived([Failure(defer.CancelledError())])
        self.assertFalse(connection._shouldRetry(b"GET", cancelled, None))

    def test_retryIfFailedDueToNonCancelException(self):
        """
        C{_shouldRetry} returns a true value for a L{ResponseNeverReceived}
        caused by some arbitrary exception, so such a request is retried.
        """
        pool = client.HTTPConnectionPool(None)
        connection = client._RetryingHTTP11ClientProtocol(None, pool)
        arbitrary = ResponseNeverReceived([Failure(Exception())])
        self.assertTrue(connection._shouldRetry(b"GET", arbitrary, None))

    def test_wrappedOnPersistentReturned(self):
        """
        A previously cached connection handed out by
        L{client.HTTPConnectionPool.getConnection} comes back wrapped in a
        L{client._RetryingHTTP11ClientProtocol}.
        """
        pool = client.HTTPConnectionPool(Clock())

        # Put a connection into the cache.
        cached = StubHTTPProtocol()
        cached.makeConnection(StringTransport())
        pool._putConnection(123, cached)

        def gotConnection(connection):
            self.assertIsInstance(connection, client._RetryingHTTP11ClientProtocol)
            self.assertIs(connection._clientProtocol, cached)

        # Asking for the same key should hand the cached connection back,
        # wrapped.
        d = pool.getConnection(123, DummyEndpoint())
        d.addCallback(gotConnection)
        return d

    def test_notWrappedOnNewReturned(self):
        """
        A newly made connection handed out by
        L{client.HTTPConnectionPool.getConnection} is returned as is.
        """
        pool = client.HTTPConnectionPool(None)

        def gotConnection(connection):
            # Compare the class by identity rather than with isinstance,
            # since the wrapper might potentially subclass it at some point.
            self.assertIs(connection.__class__, HTTP11ClientProtocol)

        d = pool.getConnection(123, DummyEndpoint())
        d.addCallback(gotConnection)
        return d

    def retryAttempt(self, willWeRetry):
        """
        Issue a request through a L{client._RetryingHTTP11ClientProtocol}
        and fail it with L{RequestNotSent}.

        @param willWeRetry: the value returned by the stubbed
            C{_shouldRetry}, which decides whether a retry follows.

        @return: a 2-tuple of the L{Deferred} returned by the retrier's
            C{request} and the list of every L{StubHTTPProtocol} created.
        """
        protocols = []

        def newProtocol():
            created = StubHTTPProtocol()
            protocols.append(created)
            return defer.succeed(created)

        bodyProducer = object()
        request = client.Request(
            b"FOO", b"/", client.Headers(), bodyProducer, persistent=True
        )

        # Create the initial connection which the retrier wraps.
        newProtocol()
        first = protocols[0]
        retrier = client._RetryingHTTP11ClientProtocol(first, newProtocol)

        def _shouldRetry(m, e, bp):
            # Check what the retrier passes in before answering as told.
            self.assertEqual(m, b"FOO")
            self.assertIs(bp, bodyProducer)
            self.assertIsInstance(e, (RequestNotSent, ResponseNeverReceived))
            return willWeRetry

        retrier._shouldRetry = _shouldRetry

        d = retrier.request(request)

        # So far a single request on a single protocol.
        self.assertEqual(len(protocols), 1)
        self.assertEqual(len(protocols[0].requests), 1)

        # Fail the first request.
        first.requests[0][1].errback(RequestNotSent())
        return d, protocols

    def test_retryIfShouldRetryReturnsTrue(self):
        """
        L{client._RetryingHTTP11ClientProtocol} retries when
        L{client._RetryingHTTP11ClientProtocol._shouldRetry} returns C{True}.
        """
        d, protocols = self.retryAttempt(True)
        # A second protocol means that the request was retried.
        self.assertEqual(len(protocols), 2)
        response = object()
        protocols[1].requests[0][1].callback(response)
        d.addCallback(self.assertIs, response)
        return d

    def test_dontRetryIfShouldRetryReturnsFalse(self):
        """
        L{client._RetryingHTTP11ClientProtocol} does not retry when
        L{client._RetryingHTTP11ClientProtocol._shouldRetry} returns C{False}.
        """
        d, protocols = self.retryAttempt(False)
        # Still a single protocol, so there was no retry.
        self.assertEqual(len(protocols), 1)
        return self.assertFailure(d, RequestNotSent)

    def test_onlyRetryWithoutBody(self):
        """
        L{_RetryingHTTP11ClientProtocol} only retries queries that don't
        have a body.

        This is an implementation restriction; if the restriction is fixed,
        this test should be removed and PUT added to list of methods that
        support retries.
        """
        pool = client.HTTPConnectionPool(None)
        connection = client._RetryingHTTP11ClientProtocol(None, pool)
        withoutBody = connection._shouldRetry(b"GET", RequestNotSent(), None)
        self.assertTrue(withoutBody)
        withBody = connection._shouldRetry(b"GET", RequestNotSent(), object())
        self.assertFalse(withBody)

    def test_onlyRetryOnce(self):
        """
        When the retried request of a
        L{client._RetryingHTTP11ClientProtocol} fails as well, no further
        retry is attempted and the second failure is reported.
        """
        d, protocols = self.retryAttempt(True)
        self.assertEqual(len(protocols), 2)
        # Fail the second request too.
        protocols[1].requests[0][1].errback(ResponseNeverReceived([]))
        # No third protocol was created, so there was no further retry.
        self.assertEqual(len(protocols), 2)
        return self.assertFailure(d, ResponseNeverReceived)

    def test_dontRetryIfRetryAutomaticallyFalse(self):
        """
        With L{HTTPConnectionPool.retryAutomatically} set to C{False},
        cached connections are handed out without the retrying wrapper.
        """
        pool = client.HTTPConnectionPool(Clock())
        pool.retryAutomatically = False

        # Put a connection into the cache.
        cached = StubHTTPProtocol()
        cached.makeConnection(StringTransport())
        pool._putConnection(123, cached)

        def gotConnection(connection):
            self.assertIs(connection, cached)

        # It should come back unwrapped.
        d = pool.getConnection(123, DummyEndpoint())
        d.addCallback(gotConnection)
        return d

    def test_retryWithNewConnection(self):
        """
        The L{client._RetryingHTTP11ClientProtocol} created by
        L{client.HTTPConnectionPool} is given a new-connection function
        which asks the pool for a new connection using the same key and
        endpoint as the wrapped connection.
        """
        pool = client.HTTPConnectionPool(Clock())
        key = 123
        endpoint = DummyEndpoint()
        recorded = []

        # Replace the pool's _newConnection with one that only records its
        # arguments.
        def newConnection(k, e):
            recorded.append((k, e))

        pool._newConnection = newConnection

        # Put a connection into the cache.
        cached = StubHTTPProtocol()
        cached.makeConnection(StringTransport())
        pool._putConnection(key, cached)

        def gotConnection(connection):
            self.assertIsInstance(connection, client._RetryingHTTP11ClientProtocol)
            self.assertIs(connection._clientProtocol, cached)
            # Calling _newConnection on the retrying connection must call
            # _newConnection on the pool with the original key and endpoint.
            self.assertEqual(recorded, [])
            connection._newConnection()
            self.assertEqual(len(recorded), 1)
            recordedKey, recordedEndpoint = recorded[0][0], recorded[0][1]
            self.assertEqual(recordedKey, key)
            self.assertIs(recordedEndpoint, endpoint)

        # It should come back wrapped in a _RetryingHTTP11ClientProtocol.
        d = pool.getConnection(key, endpoint)
        d.addCallback(gotConnection)
        return d
class CookieTestsMixin:
    """
    Helpers shared by the unit tests which deal with cookies.
    """

    def addCookies(self, cookieJar, uri, cookies):
        """
        Feed C{cookies} to C{cookieJar} as the I{Set-Cookie} headers of a
        fake 200 response to a fake request for C{uri}.

        @param cookieJar: the jar which extracts the cookies.
        @param uri: the URI of the fake request.
        @param cookies: the I{Set-Cookie} header values.

        @return: a 2-tuple of the fake request and the fake response.
        """
        headers = client.Headers({b"Set-Cookie": cookies})
        fakeResponse = client._FakeUrllib2Response(
            client.Response((b"HTTP", 1, 1), 200, b"OK", headers, None)
        )
        fakeRequest = client._FakeUrllib2Request(uri)
        cookieJar.extract_cookies(fakeResponse, fakeRequest)
        return fakeRequest, fakeResponse
class CookieJarTests(TestCase, CookieTestsMixin):
    """
    Tests for how L{twisted.web.client._FakeUrllib2Response} and
    L{twisted.web.client._FakeUrllib2Request} interact with L{CookieJar}
    instances.
    """

    def makeCookieJar(self):
        """
        @return: a 2-tuple of a L{CookieJar} holding some sample cookies and
            the (request, response) pair they were extracted from.
        """
        jar = CookieJar()
        sampleCookies = [
            b"foo=1; cow=moo; Path=/foo; Comment=hello",
            b"bar=2; Comment=goodbye",
        ]
        requestAndResponse = self.addCookies(
            jar, b"http://example.com:1234/foo?bar", sampleCookies
        )
        return jar, requestAndResponse

    def test_extractCookies(self):
        """
        L{CookieJar.extract_cookies} extracts cookie information from fake
        urllib2 response instances.
        """
        jar, _ = self.makeCookieJar()
        byName = {each.name: each for each in jar}

        foo = byName["foo"]
        self.assertEqual(foo.version, 0)
        self.assertEqual(foo.name, "foo")
        self.assertEqual(foo.value, "1")
        self.assertEqual(foo.path, "/foo")
        self.assertEqual(foo.comment, "hello")
        self.assertEqual(foo.get_nonstandard_attr("cow"), "moo")

        bar = byName["bar"]
        self.assertEqual(bar.version, 0)
        self.assertEqual(bar.name, "bar")
        self.assertEqual(bar.value, "2")
        self.assertEqual(bar.path, "/")
        self.assertEqual(bar.comment, "goodbye")
        self.assertIs(bar.get_nonstandard_attr("cow"), None)

    def test_sendCookie(self):
        """
        L{CookieJar.add_cookie_header} adds a I{Cookie} header to a fake
        urllib2 request instance.
        """
        jar, (request, response) = self.makeCookieJar()

        # No Cookie header before the jar has been asked to add one.
        self.assertIs(request.get_header("Cookie", None), None)

        jar.add_cookie_header(request)
        self.assertEqual(request.get_header("Cookie", None), "foo=1; bar=2")
class CookieAgentTests(
    TestCase, CookieTestsMixin, FakeReactorAndConnectMixin, AgentTestsMixin
):
    """
    Tests for L{twisted.web.client.CookieAgent}.
    """

    def makeAgent(self):
        """
        @return: a new L{twisted.web.client.CookieAgent}
        """
        wrapped = self.buildAgentForWrapperTest(self.reactor)
        return client.CookieAgent(wrapped, CookieJar())

    def setUp(self):
        """
        Create the fake reactor used by every test.
        """
        self.reactor = self.createReactor()

    def test_emptyCookieJarRequest(self):
        """
        With no cookie in the jar for the requested URI,
        L{CookieAgent.request} adds no I{Cookie} header to the L{Request}.
        Cookies found in the response are stored in the jar.
        """
        jar = CookieJar()
        self.assertEqual(list(jar), [])

        cookieAgent = client.CookieAgent(
            self.buildAgentForWrapperTest(self.reactor), jar
        )
        d = cookieAgent.request(b"GET", b"http://example.com:1234/foo?bar")

        def _checkCookie(ignored):
            stored = list(jar)
            self.assertEqual(len(stored), 1)
            self.assertEqual(stored[0].name, "foo")
            self.assertEqual(stored[0].value, "1")

        d.addCallback(_checkCookie)

        req, pending = self.protocol.requests.pop()
        self.assertIs(req.headers.getRawHeaders(b"cookie"), None)

        # Answer with a response which sets a single cookie.
        responseHeaders = client.Headers({b"Set-Cookie": [b"foo=1"]})
        pending.callback(
            client.Response((b"HTTP", 1, 1), 200, b"OK", responseHeaders, None)
        )
        return d

    def test_requestWithCookie(self):
        """
        L{CookieAgent.request} adds a I{Cookie} header to the L{Request}
        when the jar holds a cookie matching the request URI.
        """
        uri = b"http://example.com:1234/foo?bar"
        cookie = b"foo=1"

        jar = CookieJar()
        self.addCookies(jar, uri, [cookie])
        self.assertEqual(len(list(jar)), 1)

        cookieAgent = client.CookieAgent(
            self.buildAgentForWrapperTest(self.reactor), jar
        )
        cookieAgent.request(b"GET", uri)

        req, _ = self.protocol.requests.pop()
        self.assertEqual(req.headers.getRawHeaders(b"cookie"), [cookie])

    @skipIf(not sslPresent, "SSL not present, cannot run SSL tests.")
    def test_secureCookie(self):
        """
        L{CookieAgent} sends a cookie marked C{secure} with a request made
        over https.
        """
        uri = b"https://example.com:1234/foo?bar"

        jar = CookieJar()
        self.addCookies(jar, uri, [b"foo=1;secure"])
        self.assertEqual(len(list(jar)), 1)

        cookieAgent = client.CookieAgent(
            self.buildAgentForWrapperTest(self.reactor), jar
        )
        cookieAgent.request(b"GET", uri)

        req, _ = self.protocol.requests.pop()
        self.assertEqual(req.headers.getRawHeaders(b"cookie"), [b"foo=1"])

    def test_secureCookieOnInsecureConnection(self):
        """
        A cookie marked C{secure} is stored in the jar but is not sent with
        a request which is not made over HTTPS.
        """
        uri = b"http://example.com/foo?bar"

        jar = CookieJar()
        self.addCookies(jar, uri, [b"foo=1;secure"])
        self.assertEqual(len(list(jar)), 1)

        cookieAgent = client.CookieAgent(
            self.buildAgentForWrapperTest(self.reactor), jar
        )
        cookieAgent.request(b"GET", uri)

        req, _ = self.protocol.requests.pop()
        self.assertIs(None, req.headers.getRawHeaders(b"cookie"))

    def test_portCookie(self):
        """
        L{CookieAgent} sends a cookie which carries a C{port} directive
        with a request made to that port.
        """
        uri = b"http://example.com:1234/foo?bar"

        jar = CookieJar()
        self.addCookies(jar, uri, [b"foo=1;port=1234"])
        self.assertEqual(len(list(jar)), 1)

        cookieAgent = client.CookieAgent(
            self.buildAgentForWrapperTest(self.reactor), jar
        )
        cookieAgent.request(b"GET", uri)

        req, _ = self.protocol.requests.pop()
        self.assertEqual(req.headers.getRawHeaders(b"cookie"), [b"foo=1"])

    def test_portCookieOnWrongPort(self):
        """
        A cookie which carries a C{port} directive is not stored in the
        L{CookieJar} when the URI it came from uses a different port.
        """
        uri = b"http://example.com:4567/foo?bar"

        jar = CookieJar()
        self.addCookies(jar, uri, [b"foo=1;port=1234"])
        self.assertEqual(len(list(jar)), 0)
class Decoder1(proxyForInterface(IResponse)):  # type: ignore[misc]
    """
    A pass-through L{IResponse} proxy which the
    L{client.ContentDecoderAgent} tests register as a decoder.
    """
class Decoder2(Decoder1):
    """
    A second decoder class, distinct from L{Decoder1}, for the
    L{client.ContentDecoderAgent} tests.
    """
class ContentDecoderAgentTests(TestCase, FakeReactorAndConnectMixin, AgentTestsMixin):
"""
Tests for L{client.ContentDecoderAgent}.
"""
def makeAgent(self):
    """
    @return: a new L{twisted.web.client.ContentDecoderAgent} wrapping
        C{self.agent} with no decoders registered.
    """
    noDecoders = []
    return client.ContentDecoderAgent(self.agent, noDecoders)
def setUp(self):
"""
Create an L{Agent} wrapped around a fake reactor.
"""
self.reactor = self.createReactor()
self.agent = self.buildAgentForWrapperTest(self.reactor)
def test_acceptHeaders(self):
"""
L{client.ContentDecoderAgent} sets the I{Accept-Encoding} header to the
names of the available decoder objects.
"""
agent = client.ContentDecoderAgent(
self.agent, [(b"decoder1", Decoder1), (b"decoder2", Decoder2)]
)
agent.request(b"GET", b"http://example.com/foo")
protocol = self.protocol
self.assertEqual(len(protocol.requests), 1)
req, res = protocol.requests.pop()
self.assertEqual(
req.headers.getRawHeaders(b"accept-encoding"), [b"decoder1,decoder2"]
)
def test_existingHeaders(self):
"""
If there are existing I{Accept-Encoding} fields,
L{client.ContentDecoderAgent} creates a new field for the decoders it
knows about.
"""
headers = http_headers.Headers(
{b"foo": [b"bar"], b"accept-encoding": [b"fizz"]}
)
agent = client.ContentDecoderAgent(
self.agent, [(b"decoder1", Decoder1), (b"decoder2", Decoder2)]
)
agent.request(b"GET", b"http://example.com/foo", headers=headers)
protocol = self.protocol
self.assertEqual(len(protocol.requests), 1)
req, res = protocol.requests.pop()
self.assertEqual(
list(sorted(req.headers.getAllRawHeaders())),
[
(b"Accept-Encoding", [b"fizz", b"decoder1,decoder2"]),
(b"Foo", [b"bar"]),
(b"Host", [b"example.com"]),
],
)
def test_plainEncodingResponse(self):
"""
If the response is not encoded despited the request I{Accept-Encoding}
headers, L{client.ContentDecoderAgent} simply forwards the response.
"""
agent = client.ContentDecoderAgent(
self.agent, [(b"decoder1", Decoder1), (b"decoder2", Decoder2)]
)
deferred = agent.request(b"GET", b"http://example.com/foo")
req, res = self.protocol.requests.pop()
response = Response((b"HTTP", 1, 1), 200, b"OK", http_headers.Headers(), None)
res.callback(response)
return deferred.addCallback(self.assertIdentical, response)
def test_unsupportedEncoding(self):
"""
If an encoding unknown to the L{client.ContentDecoderAgent} is found,
the response is unchanged.
"""
agent = client.ContentDecoderAgent(
self.agent, [(b"decoder1", Decoder1), (b"decoder2", Decoder2)]
)
deferred = agent.request(b"GET", b"http://example.com/foo")
req, res = self.protocol.requests.pop()
headers = http_headers.Headers(
{b"foo": [b"bar"], b"content-encoding": [b"fizz"]}
)
response = Response((b"HTTP", 1, 1), 200, b"OK", headers, None)
res.callback(response)
return deferred.addCallback(self.assertIdentical, response)
def test_unknownEncoding(self):
"""
When L{client.ContentDecoderAgent} encounters a decoder it doesn't know
about, it stops decoding even if another encoding is known afterwards.
"""
agent = client.ContentDecoderAgent(
self.agent, [(b"decoder1", Decoder1), (b"decoder2", Decoder2)]
)
deferred = agent.request(b"GET", b"http://example.com/foo")
req, res = self.protocol.requests.pop()
headers = http_headers.Headers(
{b"foo": [b"bar"], b"content-encoding": [b"decoder1,fizz,decoder2"]}
)
response = Response((b"HTTP", 1, 1), 200, b"OK", headers, None)
res.callback(response)
def check(result):
self.assertNotIdentical(response, result)
self.assertIsInstance(result, Decoder2)
self.assertEqual(
[b"decoder1,fizz"], result.headers.getRawHeaders(b"content-encoding")
)
return deferred.addCallback(check)
class SimpleAgentProtocol(Protocol):
    """
    A body-receiving L{Protocol} for use with L{client.Agent} responses that
    records everything it is given.

    @ivar finished: L{Deferred} fired with L{None} from C{connectionLost}.
    @ivar made: L{Deferred} fired with L{None} from C{connectionMade}.
    @ivar received: C{list} of the chunks passed to C{dataReceived}, in
        arrival order.
    """

    def __init__(self):
        self.received = []
        self.made = Deferred()
        self.finished = Deferred()

    def connectionMade(self):
        """
        Signal that the connection was established.
        """
        self.made.callback(None)

    def connectionLost(self, reason):
        """
        Signal that the connection ended; C{reason} is not inspected.
        """
        self.finished.callback(None)

    def dataReceived(self, data):
        """
        Remember C{data} as one more received chunk.
        """
        self.received.append(data)
class ContentDecoderAgentWithGzipTests(TestCase, FakeReactorAndConnectMixin):
    """
    Tests for L{client.ContentDecoderAgent} configured with
    L{client.GzipDecoder} for the C{gzip} encoding.
    """

    def setUp(self):
        """
        Create a L{client.ContentDecoderAgent} that knows the C{gzip}
        encoding, wrapped around an L{Agent} on a fake reactor.
        """
        self.reactor = self.createReactor()
        agent = self.buildAgentForWrapperTest(self.reactor)
        self.agent = client.ContentDecoderAgent(agent, [(b"gzip", client.GzipDecoder)])

    def test_gzipEncodingResponse(self):
        """
        If the response has a C{gzip} I{Content-Encoding} header,
        L{GzipDecoder} wraps the response to return uncompressed data to the
        user.
        """
        deferred = self.agent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers(
            {b"foo": [b"bar"], b"content-encoding": [b"gzip"]}
        )
        transport = StringTransport()
        response = Response((b"HTTP", 1, 1), 200, b"OK", headers, transport)
        response.length = 12
        res.callback(response)

        # 16 + MAX_WBITS makes zlib emit a gzip header and trailer.
        compressor = zlib.compressobj(2, zlib.DEFLATED, 16 + zlib.MAX_WBITS)
        data = (
            compressor.compress(b"x" * 6)
            + compressor.compress(b"y" * 4)
            + compressor.flush()
        )

        def checkResponse(result):
            self.assertNotIdentical(result, response)
            self.assertEqual(result.version, (b"HTTP", 1, 1))
            self.assertEqual(result.code, 200)
            self.assertEqual(result.phrase, b"OK")
            # The Content-Encoding header is consumed by the decoder.
            self.assertEqual(
                list(result.headers.getAllRawHeaders()), [(b"Foo", [b"bar"])]
            )
            self.assertEqual(result.length, UNKNOWN_LENGTH)
            self.assertRaises(AttributeError, getattr, result, "unknown")

            # Feed the compressed body in two arbitrary pieces.
            response._bodyDataReceived(data[:5])
            response._bodyDataReceived(data[5:])
            response._bodyDataFinished()

            protocol = SimpleAgentProtocol()
            result.deliverBody(protocol)

            self.assertEqual(protocol.received, [b"x" * 6 + b"y" * 4])
            return defer.gatherResults([protocol.made, protocol.finished])

        deferred.addCallback(checkResponse)

        return deferred

    def test_brokenContent(self):
        """
        If the data received by the L{GzipDecoder} isn't valid gzip-compressed
        data, the call to C{deliverBody} fails with a C{zlib.error}.
        """
        deferred = self.agent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers(
            {b"foo": [b"bar"], b"content-encoding": [b"gzip"]}
        )
        transport = StringTransport()
        response = Response((b"HTTP", 1, 1), 200, b"OK", headers, transport)
        response.length = 12
        res.callback(response)

        data = b"not gzipped content"

        def checkResponse(result):
            response._bodyDataReceived(data)
            result.deliverBody(Protocol())

        deferred.addCallback(checkResponse)
        self.assertFailure(deferred, client.ResponseFailed)

        def checkFailure(error):
            error.reasons[0].trap(zlib.error)
            self.assertIsInstance(error.response, Response)

        return deferred.addCallback(checkFailure)

    def test_flushData(self):
        """
        When the connection with the server is lost, the gzip protocol calls
        C{flush} on the zlib decompressor object to get uncompressed data which
        may have been buffered.
        """

        class decompressobj:
            """
            Fake decompressor yielding C{b"x"} per chunk and C{b"y"} on
            flush.
            """

            def __init__(self, wbits):
                pass

            def decompress(self, data):
                return b"x"

            def flush(self):
                return b"y"

        oldDecompressObj = zlib.decompressobj
        zlib.decompressobj = decompressobj
        self.addCleanup(setattr, zlib, "decompressobj", oldDecompressObj)

        deferred = self.agent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers({b"content-encoding": [b"gzip"]})
        transport = StringTransport()
        response = Response((b"HTTP", 1, 1), 200, b"OK", headers, transport)
        res.callback(response)

        def checkResponse(result):
            response._bodyDataReceived(b"data")
            response._bodyDataFinished()

            protocol = SimpleAgentProtocol()
            result.deliverBody(protocol)

            self.assertEqual(protocol.received, [b"x", b"y"])
            return defer.gatherResults([protocol.made, protocol.finished])

        deferred.addCallback(checkResponse)

        return deferred

    def test_flushError(self):
        """
        If the C{flush} call in C{connectionLost} fails, the C{zlib.error}
        exception is caught and turned into a L{ResponseFailed}.
        """

        class decompressobj:
            """
            Fake decompressor whose C{flush} always raises C{zlib.error}.
            """

            def __init__(self, wbits):
                pass

            def decompress(self, data):
                return b"x"

            def flush(self):
                raise zlib.error()

        oldDecompressObj = zlib.decompressobj
        zlib.decompressobj = decompressobj
        self.addCleanup(setattr, zlib, "decompressobj", oldDecompressObj)

        deferred = self.agent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers({b"content-encoding": [b"gzip"]})
        transport = StringTransport()
        response = Response((b"HTTP", 1, 1), 200, b"OK", headers, transport)
        res.callback(response)

        def checkResponse(result):
            response._bodyDataReceived(b"data")
            response._bodyDataFinished()

            # This test requires the Deferred to fail with ResponseFailed, so
            # deliverBody is expected to raise here; there is nothing
            # meaningful to assert about delivered data afterwards.
            result.deliverBody(Protocol())

        deferred.addCallback(checkResponse)
        self.assertFailure(deferred, client.ResponseFailed)

        def checkFailure(error):
            # reasons[0] is the connection-lost reason; the zlib failure is
            # the second entry.
            error.reasons[1].trap(zlib.error)
            self.assertIsInstance(error.response, Response)

        return deferred.addCallback(checkFailure)
class ProxyAgentTests(TestCase, FakeReactorAndConnectMixin, AgentTestsMixin):
    """
    Tests for L{client.ProxyAgent}.
    """

    def makeAgent(self):
        """
        @return: a new L{twisted.web.client.ProxyAgent}
        """
        proxyEndpoint = TCP4ClientEndpoint(self.reactor, "127.0.0.1", 1234)
        return client.ProxyAgent(proxyEndpoint, self.reactor)

    def setUp(self):
        """
        Create a L{client.ProxyAgent} pointed at C{bar:5678} on a fake
        reactor, with its proxy endpoint replaced by a stub endpoint.
        """
        self.reactor = self.createReactor()
        self.agent = client.ProxyAgent(
            TCP4ClientEndpoint(self.reactor, "bar", 5678), self.reactor
        )
        realEndpoint = self.agent._proxyEndpoint
        self.agent._proxyEndpoint = self.StubEndpoint(realEndpoint, self)

    def test_nonBytesMethod(self):
        """
        L{ProxyAgent.request} raises L{TypeError} when the C{method} argument
        isn't L{bytes}.
        """
        self.assertRaises(TypeError, self.agent.request, "GET", b"http://foo.example/")

    def test_proxyRequest(self):
        """
        L{client.ProxyAgent} connects to the proxy and issues the request
        there, using the full absolute URI as the request path.
        """
        requestHeaders = http_headers.Headers({b"foo": [b"bar"]})
        # The body is only compared by identity, so any object will do.
        bodyMarker = object()
        self.agent.request(
            b"GET", b"http://example.com:1234/foo?bar", requestHeaders, bodyMarker
        )

        # The TCP connection goes to the proxy, not to example.com.
        host, port, factory = self.reactor.tcpClients.pop()[:3]
        self.assertEqual(host, "bar")
        self.assertEqual(port, 5678)
        self.assertIsInstance(factory._wrappedFactory, client._HTTP11ClientFactory)

        # Exactly one request should have been issued on the connection.
        issued = self.protocol.requests
        self.assertEqual(len(issued), 1)
        sentRequest, _ = issued.pop()
        self.assertIsInstance(sentRequest, Request)
        self.assertEqual(sentRequest.method, b"GET")
        self.assertEqual(sentRequest.uri, b"http://example.com:1234/foo?bar")
        expectedHeaders = http_headers.Headers(
            {b"foo": [b"bar"], b"host": [b"example.com:1234"]}
        )
        self.assertEqual(sentRequest.headers, expectedHeaders)
        self.assertIdentical(sentRequest.bodyProducer, bodyMarker)

    def test_nonPersistent(self):
        """
        C{ProxyAgent} connections are not persistent by default.
        """
        self.assertEqual(self.agent._pool.persistent, False)

    def test_connectUsesConnectionPool(self):
        """
        C{ProxyAgent} obtains its connection from the pool's
        C{getConnection} method, passing the endpoint it was constructed
        with and a key of C{("http-proxy", endpoint)}.
        """
        endpoint = DummyEndpoint()

        class DummyPool:
            """
            Pool stand-in that records whether a connection was requested.
            """

            connected = False
            persistent = False

            def getConnection(this, key, ep):
                this.connected = True
                self.assertIdentical(ep, endpoint)
                # The key is *not* tied to the final destination, but only to
                # the address of the proxy, since that's where *we* are
                # connecting:
                self.assertEqual(key, ("http-proxy", endpoint))
                return defer.succeed(StubHTTPProtocol())

        fakePool = DummyPool()
        agent = client.ProxyAgent(endpoint, self.reactor, pool=fakePool)
        self.assertIdentical(fakePool, agent._pool)

        agent.request(b"GET", b"http://foo/")
        self.assertEqual(agent._pool.connected, True)
# Lower-cased names of the request headers these tests treat as sensitive.
SENSITIVE_HEADERS = [
    b"authorization",
    b"cookie",
    b"cookie2",
    b"proxy-authorization",
    b"www-authenticate",
]
class _RedirectAgentTestsMixin(testMixinClass):
    """
    Test cases mixin for L{RedirectAgentTests} and
    L{BrowserLikeRedirectAgentTests}.

    Concrete test classes are expected to provide C{agent}, C{reactor} and
    C{protocol}.
    """

    agent: IAgent
    reactor: MemoryReactorClock
    protocol: StubHTTPProtocol

    def test_noRedirect(self):
        """
        L{client.RedirectAgent} behaves like L{client.Agent} if the response
        doesn't contain a redirect.
        """
        deferred = self.agent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers()
        response = Response((b"HTTP", 1, 1), 200, b"OK", headers, None)
        res.callback(response)

        # No follow-up request may have been issued.
        self.assertEqual(0, len(self.protocol.requests))
        result = self.successResultOf(deferred)
        self.assertIdentical(response, result)
        self.assertIdentical(result.previousResponse, None)

    def _testRedirectDefault(
        self,
        code: int,
        crossScheme: bool = False,
        crossDomain: bool = False,
        crossPort: bool = False,
        requestHeaders: Optional[Headers] = None,
    ) -> Request:
        """
        When getting a redirect, L{client.RedirectAgent} follows the URL
        specified in the L{Location} header field and make a new request.

        @param code: HTTP status code of the redirect response.
        @param crossScheme: If true, redirect to the other scheme (C{http}
            versus C{https}); the test is skipped when TLS is unavailable.
        @param crossDomain: If true, redirect to C{example.net} instead of
            C{example.com}.
        @param crossPort: If true, redirect to port 8443.
        @param requestHeaders: Headers to send with the initial request.

        @return: The follow-up request issued for the redirect target.

        @raise SkipTest: If C{crossScheme} is requested without TLS support.
        """
        startDomain = b"example.com"
        startScheme = b"https" if ssl is not None else b"http"
        startPort = 80 if startScheme == b"http" else 443
        self.agent.request(
            b"GET", startScheme + b"://" + startDomain + b"/foo", headers=requestHeaders
        )

        host, port = self.reactor.tcpClients.pop()[:2]
        self.assertEqual(EXAMPLE_COM_IP, host)
        self.assertEqual(startPort, port)

        req, res = self.protocol.requests.pop()

        # If possible (i.e.: TLS support is present), run the test with a
        # cross-scheme redirect to verify that the scheme is honored; if not,
        # let's just make sure it works at all.

        targetScheme = startScheme
        targetPort = startPort

        if crossScheme:
            if ssl is None:
                raise SkipTest(
                    "Cross-scheme redirects can't be tested without TLS support."
                )
            targetScheme = b"https" if startScheme == b"http" else b"http"
            targetPort = 443 if startPort == 80 else 80

        portSyntax = b""
        if crossPort:
            # An explicit port overrides whatever the scheme implied.
            targetPort = 8443
            portSyntax = b":8443"
        targetDomain = b"example.net" if crossDomain else startDomain
        locationValue = targetScheme + b"://" + targetDomain + portSyntax + b"/bar"
        headers = http_headers.Headers({b"location": [locationValue]})
        response = Response((b"HTTP", 1, 1), code, b"OK", headers, None)
        res.callback(response)

        req2, res2 = self.protocol.requests.pop()
        self.assertEqual(b"GET", req2.method)
        self.assertEqual(b"/bar", req2.uri)

        host, port = self.reactor.tcpClients.pop()[:2]
        self.assertEqual(EXAMPLE_NET_IP if crossDomain else EXAMPLE_COM_IP, host)
        self.assertEqual(targetPort, port)
        return req2

    def test_redirect301(self):
        """
        L{client.RedirectAgent} follows redirects on status code 301.
        """
        self._testRedirectDefault(301)

    def test_redirect301Scheme(self):
        """
        L{client.RedirectAgent} follows cross-scheme redirects.
        """
        self._testRedirectDefault(
            301,
            crossScheme=True,
        )

    def test_redirect302(self):
        """
        L{client.RedirectAgent} follows redirects on status code 302.
        """
        self._testRedirectDefault(302)

    def test_redirect307(self):
        """
        L{client.RedirectAgent} follows redirects on status code 307.
        """
        self._testRedirectDefault(307)

    def test_redirect308(self):
        """
        L{client.RedirectAgent} follows redirects on status code 308.
        """
        self._testRedirectDefault(308)

    def _sensitiveHeadersTest(
        self, expectedHostHeader: bytes = b"example.com", **crossKwargs: bool
    ) -> None:
        """
        L{client.RedirectAgent} scrubs sensitive headers when redirecting
        between differing origins.

        @param expectedHostHeader: The I{Host} header value expected on the
            request made to the redirect target.
        @param crossKwargs: The C{cross*} flags forwarded to
            L{_testRedirectDefault} to select how the origin differs.
        """
        sensitiveHeaderValues = {
            b"authorization": [b"sensitive-authnz"],
            b"cookie": [b"sensitive-cookie-data"],
            b"cookie2": [b"sensitive-cookie2-data"],
            b"proxy-authorization": [b"sensitive-proxy-auth"],
            b"wWw-auThentiCate": [b"sensitive-authn"],
            b"x-custom-sensitive": [b"sensitive-custom"],
        }
        otherHeaderValues = {b"x-random-header": [b"x-random-value"]}
        allHeaders = Headers({**sensitiveHeaderValues, **otherHeaderValues})
        redirected = self._testRedirectDefault(301, requestHeaders=allHeaders)

        def normHeaders(headers: Headers) -> dict:
            # Compare header names case-insensitively.
            return {k.lower(): v for (k, v) in headers.getAllRawHeaders()}

        # A same-origin redirect keeps every header that was sent.
        sameOriginHeaders = normHeaders(redirected.headers)
        self.assertEqual(
            sameOriginHeaders,
            {
                b"host": [b"example.com"],
                **normHeaders(allHeaders),
            },
        )

        redirectedElsewhere = self._testRedirectDefault(
            301,
            **crossKwargs,
            requestHeaders=Headers({**sensitiveHeaderValues, **otherHeaderValues}),
        )
        # A cross-origin redirect keeps only the non-sensitive headers.
        otherOriginHeaders = normHeaders(redirectedElsewhere.headers)
        self.assertEqual(
            otherOriginHeaders,
            {
                b"host": [expectedHostHeader],
                **normHeaders(Headers(otherHeaderValues)),
            },
        )

    def test_crossDomainHeaders(self) -> None:
        """
        L{client.RedirectAgent} scrubs sensitive headers when redirecting
        between differing domains.
        """
        self._sensitiveHeadersTest(crossDomain=True, expectedHostHeader=b"example.net")

    def test_crossPortHeaders(self) -> None:
        """
        L{client.RedirectAgent} scrubs sensitive headers when redirecting
        between differing ports.
        """
        self._sensitiveHeadersTest(
            crossPort=True, expectedHostHeader=b"example.com:8443"
        )

    def test_crossSchemeHeaders(self) -> None:
        """
        L{client.RedirectAgent} scrubs sensitive headers when redirecting
        between differing schemes.
        """
        self._sensitiveHeadersTest(crossScheme=True)

    def _testRedirectToGet(self, code, method):
        """
        L{client.RedirectAgent} changes the method to I{GET} when getting
        a redirect on a non-I{GET} request.

        @param code: HTTP status code.

        @param method: HTTP request method.
        """
        self.agent.request(method, b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers({b"location": [b"http://example.com/bar"]})
        response = Response((b"HTTP", 1, 1), code, b"OK", headers, None)
        res.callback(response)

        req2, res2 = self.protocol.requests.pop()
        self.assertEqual(b"GET", req2.method)
        self.assertEqual(b"/bar", req2.uri)

    def test_redirect303(self):
        """
        L{client.RedirectAgent} changes the method to I{GET} when getting a 303
        redirect on a I{POST} request.
        """
        self._testRedirectToGet(303, b"POST")

    def test_noLocationField(self):
        """
        If no L{Location} header field is found when getting a redirect,
        L{client.RedirectAgent} fails with a L{ResponseFailed} error wrapping a
        L{error.RedirectWithNoLocation} exception.
        """
        deferred = self.agent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers()
        response = Response((b"HTTP", 1, 1), 301, b"OK", headers, None)
        res.callback(response)

        fail = self.failureResultOf(deferred, client.ResponseFailed)
        fail.value.reasons[0].trap(error.RedirectWithNoLocation)
        self.assertEqual(b"http://example.com/foo", fail.value.reasons[0].value.uri)
        self.assertEqual(301, fail.value.response.code)

    def _testPageRedirectFailure(self, code, method):
        """
        When getting a redirect on an unsupported request method,
        L{client.RedirectAgent} fails with a L{ResponseFailed} error wrapping
        a L{error.PageRedirect} exception.

        @param code: HTTP status code.

        @param method: HTTP request method.
        """
        deferred = self.agent.request(method, b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers()
        response = Response((b"HTTP", 1, 1), code, b"OK", headers, None)
        res.callback(response)

        fail = self.failureResultOf(deferred, client.ResponseFailed)
        fail.value.reasons[0].trap(error.PageRedirect)
        self.assertEqual(
            b"http://example.com/foo", fail.value.reasons[0].value.location
        )
        self.assertEqual(code, fail.value.response.code)

    def test_307OnPost(self):
        """
        When getting a 307 redirect on a I{POST} request,
        L{client.RedirectAgent} fails with a L{ResponseFailed} error wrapping
        a L{error.PageRedirect} exception.
        """
        self._testPageRedirectFailure(307, b"POST")

    def test_redirectLimit(self):
        """
        If the limit of redirects specified to L{client.RedirectAgent} is
        reached, the deferred fires with L{ResponseFailed} error wrapping
        a L{InfiniteRedirection} exception.
        """
        agent = self.buildAgentForWrapperTest(self.reactor)
        # Allow a single redirect only.
        redirectAgent = client.RedirectAgent(agent, 1)

        deferred = redirectAgent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers({b"location": [b"http://example.com/bar"]})
        response = Response((b"HTTP", 1, 1), 302, b"OK", headers, None)
        res.callback(response)

        # The second redirect exceeds the limit.
        req2, res2 = self.protocol.requests.pop()

        response2 = Response((b"HTTP", 1, 1), 302, b"OK", headers, None)
        res2.callback(response2)

        fail = self.failureResultOf(deferred, client.ResponseFailed)

        fail.value.reasons[0].trap(error.InfiniteRedirection)
        self.assertEqual(
            b"http://example.com/foo", fail.value.reasons[0].value.location
        )
        self.assertEqual(302, fail.value.response.code)

    def _testRedirectURI(self, uri, location, finalURI):
        """
        When L{client.RedirectAgent} encounters a relative redirect I{URI}, it
        is resolved against the request I{URI} before following the redirect.

        @param uri: Request URI.

        @param location: I{Location} header redirect URI.

        @param finalURI: Expected final URI.
        """
        self.agent.request(b"GET", uri)

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers({b"location": [location]})
        response = Response((b"HTTP", 1, 1), 302, b"OK", headers, None)
        res.callback(response)

        req2, res2 = self.protocol.requests.pop()
        self.assertEqual(b"GET", req2.method)
        self.assertEqual(finalURI, req2.absoluteURI)

    def test_relativeURI(self):
        """
        L{client.RedirectAgent} resolves and follows relative I{URI}s in
        redirects, preserving query strings.
        """
        self._testRedirectURI(
            b"http://example.com/foo/bar", b"baz", b"http://example.com/foo/baz"
        )
        self._testRedirectURI(
            b"http://example.com/foo/bar", b"/baz", b"http://example.com/baz"
        )
        self._testRedirectURI(
            b"http://example.com/foo/bar", b"/baz?a", b"http://example.com/baz?a"
        )

    def test_relativeURIPreserveFragments(self):
        """
        L{client.RedirectAgent} resolves and follows relative I{URI}s in
        redirects, preserving fragments in way that complies with the HTTP 1.1
        bis draft.

        @see: U{https://tools.ietf.org/html/draft-ietf-httpbis-p2-semantics-22#section-7.1.2}
        """
        self._testRedirectURI(
            b"http://example.com/foo/bar#frag",
            b"/baz?a",
            b"http://example.com/baz?a#frag",
        )
        self._testRedirectURI(
            b"http://example.com/foo/bar",
            b"/baz?a#frag2",
            b"http://example.com/baz?a#frag2",
        )

    def test_relativeURISchemeRelative(self):
        """
        L{client.RedirectAgent} resolves and follows scheme relative I{URI}s in
        redirects, replacing the hostname and port when required.
        """
        self._testRedirectURI(
            b"http://example.com/foo/bar", b"//foo.com/baz", b"http://foo.com/baz"
        )
        self._testRedirectURI(
            b"http://example.com/foo/bar", b"//foo.com:81/baz", b"http://foo.com:81/baz"
        )

    def test_responseHistory(self):
        """
        L{Response.previousResponse} references the previous L{Response} from
        a redirect, or L{None} if there was no previous response.
        """
        agent = self.buildAgentForWrapperTest(self.reactor)
        redirectAgent = client.RedirectAgent(agent)

        deferred = redirectAgent.request(b"GET", b"http://example.com/foo")

        redirectReq, redirectRes = self.protocol.requests.pop()

        headers = http_headers.Headers({b"location": [b"http://example.com/bar"]})
        redirectResponse = Response((b"HTTP", 1, 1), 302, b"OK", headers, None)
        redirectRes.callback(redirectResponse)

        req, res = self.protocol.requests.pop()

        response = Response((b"HTTP", 1, 1), 200, b"OK", headers, None)
        res.callback(response)

        finalResponse = self.successResultOf(deferred)
        self.assertIdentical(finalResponse.previousResponse, redirectResponse)
        self.assertIdentical(redirectResponse.previousResponse, None)
class RedirectAgentTests(
    FakeReactorAndConnectMixin,
    _RedirectAgentTestsMixin,
    AgentTestsMixin,
    runtimeTestCase,
):
    """
    Tests for L{client.RedirectAgent}.
    """

    def makeAgent(self):
        """
        @return: a new L{twisted.web.client.RedirectAgent} that additionally
            treats C{X-Custom-sensitive} as a sensitive header.
        """
        wrapped = self.buildAgentForWrapperTest(self.reactor)
        return client.RedirectAgent(
            wrapped, sensitiveHeaderNames=[b"X-Custom-sensitive"]
        )

    def setUp(self):
        """
        Create the fake reactor and the agent under test.
        """
        self.reactor = self.createReactor()
        self.agent = self.makeAgent()

    def test_301OnPost(self):
        """
        A 301 redirect in answer to a I{POST} makes
        L{client.RedirectAgent} fail with a L{ResponseFailed} error wrapping
        a L{error.PageRedirect} exception.
        """
        self._testPageRedirectFailure(301, b"POST")

    def test_302OnPost(self):
        """
        A 302 redirect in answer to a I{POST} makes
        L{client.RedirectAgent} fail with a L{ResponseFailed} error wrapping
        a L{error.PageRedirect} exception.
        """
        self._testPageRedirectFailure(302, b"POST")
class BrowserLikeRedirectAgentTests(
    FakeReactorAndConnectMixin,
    _RedirectAgentTestsMixin,
    AgentTestsMixin,
    runtimeTestCase,
):
    """
    Tests for L{client.BrowserLikeRedirectAgent}.
    """

    def makeAgent(self):
        """
        @return: a new L{twisted.web.client.BrowserLikeRedirectAgent} that
            additionally treats C{x-Custom-sensitive} as a sensitive header.
        """
        wrapped = self.buildAgentForWrapperTest(self.reactor)
        return client.BrowserLikeRedirectAgent(
            wrapped, sensitiveHeaderNames=[b"x-Custom-sensitive"]
        )

    def setUp(self):
        """
        Create the fake reactor and the agent under test.
        """
        self.reactor = self.createReactor()
        self.agent = self.makeAgent()

    def test_redirectToGet301(self):
        """
        L{client.BrowserLikeRedirectAgent} changes the method to I{GET} when
        getting a 301 redirect on a I{POST} request.
        """
        self._testRedirectToGet(301, b"POST")

    def test_redirectToGet302(self):
        """
        L{client.BrowserLikeRedirectAgent} changes the method to I{GET} when
        getting a 302 redirect on a I{POST} request.
        """
        self._testRedirectToGet(302, b"POST")
class AbortableStringTransport(StringTransport):
    """
    A L{StringTransport} that also provides C{abortConnection}.

    @ivar aborting: C{True} once C{abortConnection} has been called.
    """

    # This should be replaced by a common version in #6530.
    aborting = False

    def abortConnection(self):
        """
        A testable stand-in for C{ITCPTransport.abortConnection}: record that
        an abort was requested, then close via C{loseConnection}, since
        aborting is a special case of closing the connection.
        """
        self.aborting = True
        self.loseConnection()
class DummyResponse:
    """
    Fake L{IResponse} for the C{readBody} tests. It remembers the protocol
    handed to C{deliverBody} and connects that protocol to its own
    transport.

    @ivar protocol: After C{deliverBody} is called, the protocol it was called
        with.

    @ivar transport: An instance created by calling C{transportFactory} which
        is used by L{DummyResponse.protocol} to make a connection.
    """

    code = 200
    phrase = b"OK"

    def __init__(self, headers=None, transportFactory=AbortableStringTransport):
        """
        @param headers: The headers for this response.  If L{None}, an empty
            L{Headers} instance will be used.
        @type headers: L{Headers}

        @param transportFactory: A callable used to construct the transport.
        """
        self.headers = Headers() if headers is None else headers
        self.transport = transportFactory()

    def deliverBody(self, protocol):
        """
        Store C{protocol} and connect it to L{DummyResponse.transport}.
        """
        self.protocol = protocol
        protocol.makeConnection(self.transport)
class AlreadyCompletedDummyResponse(DummyResponse):
    """
    A L{DummyResponse} whose transport is already gone by the time the body
    protocol could use it.
    """

    def deliverBody(self, protocol):
        """
        Connect C{protocol} exactly as L{DummyResponse.deliverBody} does,
        then drop the protocol's reference to the transport.
        """
        super().deliverBody(protocol)
        self.protocol.transport = None
class ReadBodyTests(TestCase):
    """
    Tests for L{client.readBody}
    """

    def test_success(self):
        """
        The L{Deferred} returned by L{client.readBody} fires with all the
        body chunks of the L{IResponse} provider joined together.
        """
        fakeResponse = DummyResponse()
        bodyDeferred = client.readBody(fakeResponse)
        for chunk in (b"first", b"second"):
            fakeResponse.protocol.dataReceived(chunk)
        fakeResponse.protocol.connectionLost(Failure(ResponseDone()))
        self.assertEqual(self.successResultOf(bodyDeferred), b"firstsecond")

    def test_cancel(self):
        """
        Cancelling the L{Deferred} returned by L{client.readBody} aborts the
        connection to the server.
        """
        fakeResponse = DummyResponse()
        bodyDeferred = client.readBody(fakeResponse)
        bodyDeferred.cancel()
        self.failureResultOf(bodyDeferred, defer.CancelledError)
        self.assertTrue(fakeResponse.transport.aborting)

    def test_withPotentialDataLoss(self):
        """
        If the body handed to L{client.readBody} may be incomplete, the
        returned L{Deferred} fails with L{client.PartialDownloadError}
        carrying the status, message and the content received so far.
        """
        fakeResponse = DummyResponse()
        bodyDeferred = client.readBody(fakeResponse)
        for chunk in (b"first", b"second"):
            fakeResponse.protocol.dataReceived(chunk)
        fakeResponse.protocol.connectionLost(Failure(PotentialDataLoss()))

        failure = self.failureResultOf(bodyDeferred)
        failure.trap(client.PartialDownloadError)
        observed = {
            "status": failure.value.status,
            "message": failure.value.message,
            "body": failure.value.response,
        }
        expected = {
            "status": b"200",
            "message": b"OK",
            "body": b"firstsecond",
        }
        self.assertEqual(observed, expected)

    def test_otherErrors(self):
        """
        Any connection-lost reason other than L{client.PotentialDataLoss}
        is passed through unchanged as the failure of the L{Deferred}
        returned by L{client.readBody}.
        """
        fakeResponse = DummyResponse()
        bodyDeferred = client.readBody(fakeResponse)
        fakeResponse.protocol.dataReceived(b"first")
        fakeResponse.protocol.connectionLost(
            Failure(ConnectionLost("mystery problem"))
        )
        reason = self.failureResultOf(bodyDeferred)
        reason.trap(ConnectionLost)
        self.assertEqual(reason.value.args, ("mystery problem",))

    def test_deprecatedTransport(self):
        """
        Calling L{client.readBody} with a transport that does not implement
        L{twisted.internet.interfaces.ITCPTransport} produces a deprecation
        warning, but no exception when cancelling.
        """
        fakeResponse = DummyResponse(transportFactory=StringTransport)
        fakeResponse.transport.abortConnection = None
        expectedMessage = (
            "Using readBody with a transport that does not have an "
            "abortConnection method"
        )
        # NOTE(review): readBody is called through a lambda defined in this
        # file, presumably so the warning is attributed to __file__; keep it.
        bodyDeferred = self.assertWarns(
            DeprecationWarning,
            expectedMessage,
            __file__,
            lambda: client.readBody(fakeResponse),
        )
        bodyDeferred.cancel()
        self.failureResultOf(bodyDeferred, defer.CancelledError)

    def test_deprecatedTransportNoWarning(self):
        """
        No deprecation warning is emitted when L{client.readBody} is given a
        response whose transport has already been closed (eg. for a very
        small request).
        """
        client.readBody(AlreadyCompletedDummyResponse())
        self.assertEqual(len(self.flushWarnings()), 0)
@skipIf(not sslPresent, "SSL not present, cannot run SSL tests.")
class HostnameCachingHTTPSPolicyTests(TestCase):
    """
    Tests for L{HostnameCachingHTTPSPolicy}.
    """

    def test_cacheIsUsed(self):
        """
        Verify that the connection creator is added to the
        policy's cache, and that it is reused on subsequent calls
        to creatorForNetLoc.
        """
        trustRoot = CustomOpenSSLTrustRoot()
        wrappedPolicy = BrowserLikePolicyForHTTPS(trustRoot=trustRoot)
        policy = HostnameCachingHTTPSPolicy(wrappedPolicy)
        creator = policy.creatorForNetloc(b"foo", 1589)
        self.assertTrue(trustRoot.called)
        # Reset the flag so a second use of the trust root is detectable.
        trustRoot.called = False
        self.assertEqual(1, len(policy._cache))
        connection = creator.clientConnectionForTLS(None)
        self.assertIs(trustRoot.context, connection.get_context())

        # The same host again must be served from the cache.
        policy.creatorForNetloc(b"foo", 1589)
        self.assertFalse(trustRoot.called)

    def test_cacheRemovesOldest(self):
        """
        Verify that when the cache is full, and a new entry is added,
        the oldest entry is removed.
        """
        trustRoot = CustomOpenSSLTrustRoot()
        wrappedPolicy = BrowserLikePolicyForHTTPS(trustRoot=trustRoot)
        policy = HostnameCachingHTTPSPolicy(wrappedPolicy)
        for i in range(20):
            hostname = "host" + str(i)
            policy.creatorForNetloc(hostname.encode("ascii"), 8675)

        # Force host0, which was the first, to be the most recently used
        host0 = "host0"
        policy.creatorForNetloc(host0.encode("ascii"), 309)
        self.assertIn(host0, policy._cache)
        self.assertEqual(20, len(policy._cache))

        # One more host pushes out host1, now the least recently used.
        hostn = "new"
        policy.creatorForNetloc(hostn.encode("ascii"), 309)

        host1 = "host1"
        self.assertNotIn(host1, policy._cache)
        self.assertEqual(20, len(policy._cache))

        self.assertIn(hostn, policy._cache)
        self.assertIn(host0, policy._cache)

        # Accessing an item repeatedly does not corrupt the LRU.
        for _ in range(20):
            policy.creatorForNetloc(host0.encode("ascii"), 8675)

        hostNPlus1 = "new1"

        policy.creatorForNetloc(hostNPlus1.encode("ascii"), 800)

        self.assertNotIn("host2", policy._cache)
        self.assertEqual(20, len(policy._cache))

        self.assertIn(hostNPlus1, policy._cache)
        self.assertIn(hostn, policy._cache)
        self.assertIn(host0, policy._cache)

    def test_changeCacheSize(self):
        """
        Verify that changing the cache size results in a policy that
        respects the new cache size and not the default.
        """
        trustRoot = CustomOpenSSLTrustRoot()
        wrappedPolicy = BrowserLikePolicyForHTTPS(trustRoot=trustRoot)
        policy = HostnameCachingHTTPSPolicy(wrappedPolicy, cacheSize=5)
        for i in range(5):
            hostname = "host" + str(i)
            policy.creatorForNetloc(hostname.encode("ascii"), 8675)

        first = "host0"
        self.assertIn(first, policy._cache)
        self.assertEqual(5, len(policy._cache))

        # A sixth host must evict the first one.
        hostn = "new"
        policy.creatorForNetloc(hostn.encode("ascii"), 309)
        self.assertNotIn(first, policy._cache)
        self.assertEqual(5, len(policy._cache))

        self.assertIn(hostn, policy._cache)
class RequestMethodInjectionTests(
    MethodInjectionTestsMixin,
    SynchronousTestCase,
):
    """
    Test L{client.Request} against HTTP method injections.
    """

    def attemptRequestWithMaliciousMethod(self, method):
        """
        Construct a L{client.Request} using the provided method.

        @param method: see L{MethodInjectionTestsMixin}
        """
        emptyHeaders = http_headers.Headers()
        client.Request(
            method=method,
            uri=b"http://twisted.invalid",
            headers=emptyHeaders,
            bodyProducer=None,
        )
class RequestWriteToMethodInjectionTests(
    MethodInjectionTestsMixin,
    SynchronousTestCase,
):
    """
    Test L{client.Request.writeTo} against HTTP method injections.
    """

    def attemptRequestWithMaliciousMethod(self, method):
        """
        Build a valid request, swap in the provided method afterwards, and
        write the request to a transport.

        @param method: see L{MethodInjectionTestsMixin}
        """
        request = client.Request(
            method=b"GET",
            uri=b"http://twisted.invalid",
            headers=http_headers.Headers({b"Host": [b"twisted.invalid"]}),
            bodyProducer=None,
        )
        # Set the attribute after construction so that only writeTo is
        # exercised with the provided method.
        request.method = method
        request.writeTo(StringTransport())
class RequestURIInjectionTests(
    URIInjectionTestsMixin,
    SynchronousTestCase,
):
    """
    Test L{client.Request} against HTTP URI injections.
    """

    def attemptRequestWithMaliciousURI(self, uri):
        """
        Construct a L{client.Request} using the provided URI.

        @param uri: see L{URIInjectionTestsMixin}
        """
        emptyHeaders = http_headers.Headers()
        client.Request(
            method=b"GET",
            uri=uri,
            headers=emptyHeaders,
            bodyProducer=None,
        )
class RequestWriteToURIInjectionTests(
    URIInjectionTestsMixin,
    SynchronousTestCase,
):
    """
    Test L{client.Request.writeTo} against HTTP URI injections.
    """

    def attemptRequestWithMaliciousURI(self, uri):
        """
        Build a valid request, swap in the provided URI afterwards, and
        write the request to a transport.

        @param uri: see L{URIInjectionTestsMixin}
        """
        request = client.Request(
            method=b"GET",
            uri=b"http://twisted.invalid",
            headers=http_headers.Headers({b"Host": [b"twisted.invalid"]}),
            bodyProducer=None,
        )
        # Set the attribute after construction so that only writeTo is
        # exercised with the provided URI.
        request.uri = uri
        request.writeTo(StringTransport())
Zerion Mini Shell 1.0