Commit 39cf8a2: Compact the implementation
Gallaecio committed Nov 14, 2023 (1 parent: 9877070)

Showing 5 changed files with 174 additions and 65 deletions.
22 changes: 6 additions & 16 deletions README.rst
@@ -65,11 +65,10 @@ Then, set up the scrapy-zyte-api integration:
  }
  DOWNLOADER_MIDDLEWARES = {
      "scrapy_zyte_api.ScrapyZyteAPIDownloaderMiddleware": 1000,
-     "scrapy_zyte_api.ForbiddenDomainDownloaderMiddleware": 1100,
  }
  REQUEST_FINGERPRINTER_CLASS = "scrapy_zyte_api.ScrapyZyteAPIRequestFingerprinter"
  SPIDER_MIDDLEWARES = {
-     "scrapy_zyte_api.ForbiddenDomainSpiderMiddleware": 100,
+     "scrapy_zyte_api.ScrapyZyteAPISpiderMiddleware": 100,
  }
  TWISTED_REACTOR = "twisted.internet.asyncioreactor.AsyncioSelectorReactor"
@@ -95,6 +94,11 @@ To enable this plugin:
    <https://docs.scrapy.org/en/latest/topics/settings.html#downloader-middlewares>`_
    Scrapy setting with any value, e.g. ``1000``.

+ - Add ``"scrapy_zyte_api.ScrapyZyteAPISpiderMiddleware"`` to the
+   `SPIDER_MIDDLEWARES
+   <https://docs.scrapy.org/en/latest/topics/settings.html#spider-middlewares>`_
+   Scrapy setting with any value, e.g. ``100``.
+
  - Set the `REQUEST_FINGERPRINTER_CLASS
    <https://docs.scrapy.org/en/latest/topics/request-response.html#request-fingerprinter-class>`_
    Scrapy setting to ``"scrapy_zyte_api.ScrapyZyteAPIRequestFingerprinter"``.
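
Taken together, the steps visible in this hunk amount to a settings module along these lines. This is a sketch assembled from the snippets in this commit: the download-handler paths are taken from this commit's test settings, the priority numbers are the documented examples rather than requirements, and the API key is a placeholder you must replace.

    # settings.py -- sketch assembled from the README and test settings above
    DOWNLOAD_HANDLERS = {
        "http": "scrapy_zyte_api.handler.ScrapyZyteAPIDownloadHandler",
        "https": "scrapy_zyte_api.handler.ScrapyZyteAPIDownloadHandler",
    }
    DOWNLOADER_MIDDLEWARES = {
        "scrapy_zyte_api.ScrapyZyteAPIDownloaderMiddleware": 1000,
    }
    SPIDER_MIDDLEWARES = {
        "scrapy_zyte_api.ScrapyZyteAPISpiderMiddleware": 100,
    }
    REQUEST_FINGERPRINTER_CLASS = "scrapy_zyte_api.ScrapyZyteAPIRequestFingerprinter"
    TWISTED_REACTOR = "twisted.internet.asyncioreactor.AsyncioSelectorReactor"
    ZYTE_API_KEY = "YOUR_API_KEY"  # placeholder: use your own key
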
@@ -139,20 +143,6 @@ If you want to use scrapy-poet integration, add a provider to
"scrapy_zyte_api.providers.ZyteApiProvider": 1100,
}
To have your spiders finish with ``failed_forbidden_domain`` as a close reason
when all start URLs belong to domains forbidden by Zyte API, edit your Scrapy
settings further as follows:

- Add ``"scrapy_zyte_api.ForbiddenDomainDownloaderMiddleware"`` to the
`DOWNLOADER_MIDDLEWARES
<https://docs.scrapy.org/en/latest/topics/settings.html#downloader-middlewares>`_
Scrapy setting with any value, e.g. ``1100``.

- Add ``"scrapy_zyte_api.ForbiddenDomainSpiderMiddleware"`` to the
`SPIDER_MIDDLEWARES
<https://docs.scrapy.org/en/latest/topics/settings.html#spider-middlewares>`_
Scrapy setting with any value, e.g. ``100``.

Usage
=====

3 changes: 1 addition & 2 deletions scrapy_zyte_api/__init__.py
@@ -6,9 +6,8 @@
  install_reactor("twisted.internet.asyncioreactor.AsyncioSelectorReactor")

  from ._middlewares import (
-     ForbiddenDomainDownloaderMiddleware,
-     ForbiddenDomainSpiderMiddleware,
      ScrapyZyteAPIDownloaderMiddleware,
+     ScrapyZyteAPISpiderMiddleware,
  )
  from ._request_fingerprinter import ScrapyZyteAPIRequestFingerprinter
  from .handler import ScrapyZyteAPIDownloadHandler
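
With the two ForbiddenDomain classes gone from the package namespace, downstream code that imported them would now reference the consolidated pair instead, e.g.:

    from scrapy_zyte_api import (
        ScrapyZyteAPIDownloaderMiddleware,
        ScrapyZyteAPISpiderMiddleware,
    )
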
81 changes: 34 additions & 47 deletions scrapy_zyte_api/_middlewares.py
@@ -8,6 +8,9 @@
  logger = logging.getLogger(__name__)


+ _start_requests_processed = object()
+
+
  class ScrapyZyteAPIDownloaderMiddleware:
      _slot_prefix = "zyte-api@"

@@ -16,6 +19,8 @@ def from_crawler(cls, crawler):
          return cls(crawler)

      def __init__(self, crawler) -> None:
+         self._forbidden_domain_start_request_count = 0
+         self._total_start_request_count = 0
          self._param_parser = _ParamParser(crawler, cookies_enabled=False)
          self._crawler = crawler

@@ -27,6 +32,14 @@ def __init__(self, crawler) -> None:
f"reached."
)

crawler.signals.connect(
self._start_requests_processed, signal=_start_requests_processed
)

def _start_requests_processed(self, count):
self._total_start_request_count = count
self._maybe_close()

      def process_request(self, request, spider):
          if self._param_parser.parse(request) is None:
              return
@@ -61,51 +74,6 @@ def _max_requests_reached(self, downloader) -> bool:
          total_requests = zapi_req_count + download_req_count
          return total_requests >= self._max_requests


- _start_requests_processed = object()
-
-
- class ForbiddenDomainSpiderMiddleware:
-     """Marks start requests and reports to
-     :class:`ForbiddenDomainDownloaderMiddleware` the number of them once all
-     have been processed."""
-
-     @classmethod
-     def from_crawler(cls, crawler):
-         return cls(crawler)
-
-     def __init__(self, crawler):
-         self._send_signal = crawler.signals.send_catch_log
-
-     def process_start_requests(self, start_requests, spider):
-         count = 0
-         for request in start_requests:
-             request.meta["is_start_request"] = True
-             yield request
-             count += 1
-         self._send_signal(_start_requests_processed, count=count)
-
-
- class ForbiddenDomainDownloaderMiddleware:
-     """Closes the spider with ``failed_forbidden_domain`` as close reason if
-     all start requests get a 451 response from Zyte API."""
-
-     @classmethod
-     def from_crawler(cls, crawler):
-         return cls(crawler)
-
-     def __init__(self, crawler):
-         self._failed_start_request_count = 0
-         self._total_start_request_count = 0
-         crawler.signals.connect(
-             self._start_requests_processed, signal=_start_requests_processed
-         )
-         self._crawler = crawler
-
-     def _start_requests_processed(self, count):
-         self._total_start_request_count = count
-         self._maybe_close()
-
      def process_exception(self, request, exception, spider):
          if (
              not request.meta.get("is_start_request")
@@ -114,13 +82,13 @@ def process_exception(self, request, exception, spider):
          ):
              return

-         self._failed_start_request_count += 1
+         self._forbidden_domain_start_request_count += 1
          self._maybe_close()

      def _maybe_close(self):
          if not self._total_start_request_count:
              return
-         if self._failed_start_request_count < self._total_start_request_count:
+         if self._forbidden_domain_start_request_count < self._total_start_request_count:
              return
          logger.error(
              "Stopping the spider, all start requests failed because they "
@@ -129,3 +97,22 @@ def _maybe_close(self):
          self._crawler.engine.close_spider(
              self._crawler.spider, "failed_forbidden_domain"
          )
+
+
+ class ScrapyZyteAPISpiderMiddleware:
+     @classmethod
+     def from_crawler(cls, crawler):
+         return cls(crawler)
+
+     def __init__(self, crawler):
+         self._send_signal = crawler.signals.send_catch_log
+
+     def process_start_requests(self, start_requests, spider):
+         # Mark start requests and report to the downloader middleware the
+         # number of them once all have been processed.
+         count = 0
+         for request in start_requests:
+             request.meta["is_start_request"] = True
+             yield request
+             count += 1
+         self._send_signal(_start_requests_processed, count=count)
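
The two remaining middlewares share one handshake: ``ScrapyZyteAPISpiderMiddleware`` counts start requests as the engine consumes them and then fires the module-private ``_start_requests_processed`` signal, which ``ScrapyZyteAPIDownloaderMiddleware`` listens for; only once the total is known can ``_maybe_close`` compare it against the forbidden-domain counter. A minimal, self-contained sketch of that signal round-trip, using Scrapy's ``SignalManager`` directly (the ``Receiver`` class is illustrative, not part of the codebase):

    from scrapy.signalmanager import SignalManager

    # Any unique object works as a signal identifier; signals are compared
    # by identity, so no registration step is needed.
    _start_requests_processed = object()

    class Receiver:
        def __init__(self):
            self.total = None

        def handler(self, count):
            self.total = count

    signals = SignalManager()
    receiver = Receiver()
    signals.connect(receiver.handler, signal=_start_requests_processed)
    # send_catch_log delivers keyword arguments to every connected handler
    # and logs, rather than raises, any handler exception.
    signals.send_catch_log(_start_requests_processed, count=3)
    assert receiver.total == 3

When every start request has failed with a forbidden-domain error, the downloader middleware calls ``crawler.engine.close_spider(spider, "failed_forbidden_domain")``; Scrapy's core stats extension records that close reason as the ``finish_reason`` stat that the tests below assert on.
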
6 changes: 6 additions & 0 deletions tests/__init__.py
@@ -17,8 +17,14 @@
"http": "scrapy_zyte_api.handler.ScrapyZyteAPIDownloadHandler",
"https": "scrapy_zyte_api.handler.ScrapyZyteAPIDownloadHandler",
},
"DOWNLOADER_MIDDLEWARES": {
"scrapy_zyte_api.ScrapyZyteAPIDownloaderMiddleware": 1000,
},
"REQUEST_FINGERPRINTER_CLASS": "scrapy_zyte_api.ScrapyZyteAPIRequestFingerprinter",
"REQUEST_FINGERPRINTER_IMPLEMENTATION": "2.7", # Silence deprecation warning
"SPIDER_MIDDLEWARES": {
"scrapy_zyte_api.ScrapyZyteAPISpiderMiddleware": 100,
},
"ZYTE_API_KEY": _API_KEY,
"TWISTED_REACTOR": "twisted.internet.asyncioreactor.AsyncioSelectorReactor",
}
127 changes: 127 additions & 0 deletions tests/test_downloader_middleware.py → tests/test_middlewares.py
@@ -132,3 +132,130 @@ def parse(self, response):
          )
          > 0
      )


+ @ensureDeferred
+ async def test_forbidden_domain_start_url():
+     class TestSpider(Spider):
+         name = "test"
+         start_urls = ["https://forbidden.example"]
+
+         def parse(self, response):
+             pass
+
+     settings = {
+         "ZYTE_API_TRANSPARENT_MODE": True,
+         **SETTINGS,
+     }
+
+     with MockServer() as server:
+         settings["ZYTE_API_URL"] = server.urljoin("/")
+         crawler = get_crawler(TestSpider, settings_dict=settings)
+         await crawler.crawl()
+
+     assert crawler.stats.get_value("finish_reason") == "failed_forbidden_domain"
+
+
+ @ensureDeferred
+ async def test_forbidden_domain_start_urls():
+     class TestSpider(Spider):
+         name = "test"
+         start_urls = [
+             "https://forbidden.example",
+             "https://also-forbidden.example",
+             "https://oh.definitely-forbidden.example",
+         ]
+
+         def parse(self, response):
+             pass
+
+     settings = {
+         "ZYTE_API_TRANSPARENT_MODE": True,
+         **SETTINGS,
+     }
+
+     with MockServer() as server:
+         settings["ZYTE_API_URL"] = server.urljoin("/")
+         crawler = get_crawler(TestSpider, settings_dict=settings)
+         await crawler.crawl()
+
+     assert crawler.stats.get_value("finish_reason") == "failed_forbidden_domain"
+
+
+ @ensureDeferred
+ async def test_some_forbidden_domain_start_url():
+     class TestSpider(Spider):
+         name = "test"
+         start_urls = [
+             "https://forbidden.example",
+             "https://allowed.example",
+         ]
+
+         def parse(self, response):
+             pass
+
+     settings = {
+         "ZYTE_API_TRANSPARENT_MODE": True,
+         **SETTINGS,
+     }
+
+     with MockServer() as server:
+         settings["ZYTE_API_URL"] = server.urljoin("/")
+         crawler = get_crawler(TestSpider, settings_dict=settings)
+         await crawler.crawl()
+
+     assert crawler.stats.get_value("finish_reason") == "finished"
+
+
+ @ensureDeferred
+ async def test_follow_up_forbidden_domain_url():
+     class TestSpider(Spider):
+         name = "test"
+         start_urls = [
+             "https://allowed.example",
+         ]
+
+         def parse(self, response):
+             yield response.follow("https://forbidden.example")
+
+     settings = {
+         "ZYTE_API_TRANSPARENT_MODE": True,
+         **SETTINGS,
+     }
+
+     with MockServer() as server:
+         settings["ZYTE_API_URL"] = server.urljoin("/")
+         crawler = get_crawler(TestSpider, settings_dict=settings)
+         await crawler.crawl()
+
+     assert crawler.stats.get_value("finish_reason") == "finished"
+
+
+ @ensureDeferred
+ async def test_forbidden_domain_with_partial_start_request_consumption():
+     """With concurrency lower than the number of start requests plus one,
+     a different code path is followed, because
+     ``_total_start_request_count`` is not set in the downloader middleware
+     until *after* some start requests have been processed."""
+
+     class TestSpider(Spider):
+         name = "test"
+         start_urls = [
+             "https://forbidden.example",
+         ]
+
+         def parse(self, response):
+             yield response.follow("https://forbidden.example")
+
+     settings = {
+         "CONCURRENT_REQUESTS": 1,
+         "ZYTE_API_TRANSPARENT_MODE": True,
+         **SETTINGS,
+     }
+
+     with MockServer() as server:
+         settings["ZYTE_API_URL"] = server.urljoin("/")
+         crawler = get_crawler(TestSpider, settings_dict=settings)
+         await crawler.crawl()
+
+     assert crawler.stats.get_value("finish_reason") == "failed_forbidden_domain"

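The ``MockServer`` these tests rely on is not part of this diff. As a rough illustration only, a stand-in endpoint emulating the behavior the tests expect — rejecting any request whose target URL is on a "forbidden" domain — could look like the aiohttp sketch below. The endpoint path and the error payload shape are guesses; the real Zyte API error format for forbidden domains is not shown in this commit.

    from aiohttp import web

    async def extract(request):
        data = await request.json()
        if "forbidden" in data.get("url", ""):
            # Hypothetical domain-forbidden body; HTTP 451 matches the
            # behavior described in the removed middleware docstring.
            return web.json_response(
                {"status": 451, "type": "/download/domain-forbidden"},
                status=451,
            )
        # Minimal successful response; a real mock would return
        # base64-encoded httpResponseBody content.
        return web.json_response({"url": data["url"], "httpResponseBody": ""})

    app = web.Application()
    app.add_routes([web.post("/extract", extract)])
    # web.run_app(app, port=8899)  # uncomment to run the mock manually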