use PyPI as fallback if stable binaries are not on PyTorch indices (#93)

* use PyPI as fallback if stable binaries are not on PyTorch indices

* update readme

* improve readme
This commit is contained in:
Philip Meier
2022-08-17 10:27:59 +02:00
committed by GitHub
parent 6bdc262e3f
commit d848f67e9e
2 changed files with 43 additions and 6 deletions

View File

@@ -34,9 +34,11 @@ index, has some limitations:
hand your NVIDIA driver version simply doesn't support the CUDA version the binary
was compiled with, you can't use any of the GPU features.
To overcome this, PyTorch also hosts _all_ binaries
[themselves](https://download.pytorch.org/whl). To access them, you can still use
`pip install` them, but some
To overcome this, PyTorch also hosts _most_ binaries
[on their own package indices](https://download.pytorch.org/whl). Some distributions are
not compiled against a specific computation backend and thus hosting them on PyPI is
sufficient since they work in every environment. To access PyTorch's package indices,
you can still use `pip install`, but some
[additional options](https://pytorch.org/get-started/locally/) are needed:
```shell

View File

@@ -11,9 +11,10 @@ from typing import List, Set
from unittest import mock
import pip._internal.cli.cmdoptions
import pip._internal.index.collector
import pip._internal.index.package_finder
from pip._internal.index.collector import CollectedSources
from pip._internal.index.package_finder import CandidateEvaluator
from pip._internal.index.sources import build_source
from pip._internal.models.search_scope import SearchScope
import light_the_torch as ltt
@@ -228,7 +229,7 @@ def get_extra_index_urls(computation_backends, channel):
@contextlib.contextmanager
def patch_link_collection(computation_backends, channel):
search_scope = SearchScope.create(
search_scope = SearchScope(
find_links=[], index_urls=get_extra_index_urls(computation_backends, channel)
)
@@ -241,6 +242,39 @@ def patch_link_collection(computation_backends, channel):
with mock.patch.object(input.self, "search_scope", search_scope):
yield
def postprocessing(input, output):
if input.project_name not in PYTORCH_DISTRIBUTIONS:
return output
if channel != Channel.STABLE:
return output
# Some stable binaries are not hosted on the PyTorch indices. We check if this
# is the case for the current distribution.
for remote_file_source in output.index_urls:
candidates = list(remote_file_source.page_candidates())
# Cache the candidates, so `pip` doesn't has to retrieve them again later.
remote_file_source.page_candidates = lambda: iter(candidates)
# If there are any candidates on the PyTorch indices, we continue normally.
if candidates:
return output
# In case the distribution is not present on the PyTorch indices, we fall back
# to PyPI.
_, pypi_file_source = build_source(
SearchScope(
find_links=[], index_urls=["https://pypi.org/simple"]
).get_index_urls_locations(input.project_name)[0],
candidates_from_page=input.candidates_from_page,
page_validator=input.self.session.is_secure_origin,
expand_dir=False,
cache_link_parsing=False,
)
return CollectedSources(find_links=[], index_urls=[pypi_file_source])
with apply_fn_patch(
"pip",
"_internal",
@@ -249,6 +283,7 @@ def patch_link_collection(computation_backends, channel):
"LinkCollector",
"collect_sources",
context=context,
postprocessing=postprocessing,
):
yield