[HttpClient] preserve the identity of responses streamed by TraceableHttpClient

This commit is contained in:
Nicolas Grekas 2020-05-09 22:11:42 +02:00
parent 2ed6a0d74c
commit afc44dae16
3 changed files with 29 additions and 12 deletions

View File

@ -14,6 +14,7 @@ namespace Symfony\Component\HttpClient\Response;
use Symfony\Component\HttpClient\Exception\ClientException; use Symfony\Component\HttpClient\Exception\ClientException;
use Symfony\Component\HttpClient\Exception\RedirectionException; use Symfony\Component\HttpClient\Exception\RedirectionException;
use Symfony\Component\HttpClient\Exception\ServerException; use Symfony\Component\HttpClient\Exception\ServerException;
use Symfony\Component\HttpClient\TraceableHttpClient;
use Symfony\Contracts\HttpClient\Exception\ClientExceptionInterface; use Symfony\Contracts\HttpClient\Exception\ClientExceptionInterface;
use Symfony\Contracts\HttpClient\Exception\RedirectionExceptionInterface; use Symfony\Contracts\HttpClient\Exception\RedirectionExceptionInterface;
use Symfony\Contracts\HttpClient\Exception\ServerExceptionInterface; use Symfony\Contracts\HttpClient\Exception\ServerExceptionInterface;
@ -105,6 +106,28 @@ class TraceableResponse implements ResponseInterface
return StreamWrapper::createResource($this->response, $this->client); return StreamWrapper::createResource($this->response, $this->client);
} }
/**
* @internal
*/
public static function stream(HttpClientInterface $client, iterable $responses, ?float $timeout): \Generator
{
$wrappedResponses = [];
$traceableMap = new \SplObjectStorage();
foreach ($responses as $r) {
if (!$r instanceof self) {
throw new \TypeError(sprintf('"%s::stream()" expects parameter 1 to be an iterable of TraceableResponse objects, "%s" given.', TraceableHttpClient::class, get_debug_type($r)));
}
$traceableMap[$r->response] = $r;
$wrappedResponses[] = $r->response;
}
foreach ($client->stream($wrappedResponses, $timeout) as $r => $chunk) {
yield $traceableMap[$r] => $chunk;
}
}
private function checkStatusCode($code) private function checkStatusCode($code)
{ {
if (500 <= $code) { if (500 <= $code) {

View File

@ -88,11 +88,12 @@ class TraceableHttpClientTest extends TestCase
TestHttpServer::start(); TestHttpServer::start();
$sut = new TraceableHttpClient(new NativeHttpClient()); $sut = new TraceableHttpClient(new NativeHttpClient());
$chunked = $sut->request('GET', 'http://localhost:8057/chunked'); $response = $sut->request('GET', 'http://localhost:8057/chunked');
$chunks = []; $chunks = [];
foreach ($sut->stream($chunked) as $response) { foreach ($sut->stream($response) as $r => $chunk) {
$chunks[] = $response->getContent(); $chunks[] = $chunk->getContent();
} }
$this->assertSame($response, $r);
$this->assertGreaterThan(1, \count($chunks)); $this->assertGreaterThan(1, \count($chunks));
$this->assertSame('Symfony is awesome!', implode('', $chunks)); $this->assertSame('Symfony is awesome!', implode('', $chunks));
} }

View File

@ -13,6 +13,7 @@ namespace Symfony\Component\HttpClient;
use Psr\Log\LoggerAwareInterface; use Psr\Log\LoggerAwareInterface;
use Psr\Log\LoggerInterface; use Psr\Log\LoggerInterface;
use Symfony\Component\HttpClient\Response\ResponseStream;
use Symfony\Component\HttpClient\Response\TraceableResponse; use Symfony\Component\HttpClient\Response\TraceableResponse;
use Symfony\Contracts\HttpClient\HttpClientInterface; use Symfony\Contracts\HttpClient\HttpClientInterface;
use Symfony\Contracts\HttpClient\ResponseInterface; use Symfony\Contracts\HttpClient\ResponseInterface;
@ -70,15 +71,7 @@ final class TraceableHttpClient implements HttpClientInterface, ResetInterface,
throw new \TypeError(sprintf('"%s()" expects parameter 1 to be an iterable of TraceableResponse objects, "%s" given.', __METHOD__, get_debug_type($responses))); throw new \TypeError(sprintf('"%s()" expects parameter 1 to be an iterable of TraceableResponse objects, "%s" given.', __METHOD__, get_debug_type($responses)));
} }
return $this->client->stream(\Closure::bind(static function () use ($responses) { return new ResponseStream(TraceableResponse::stream($this->client, $responses, $timeout));
foreach ($responses as $k => $r) {
if (!$r instanceof TraceableResponse) {
throw new \TypeError(sprintf('"%s()" expects parameter 1 to be an iterable of TraceableResponse objects, "%s" given.', __METHOD__, get_debug_type($r)));
}
yield $k => $r->response;
}
}, null, TraceableResponse::class)(), $timeout);
} }
public function getTracedRequests(): array public function getTracedRequests(): array