move PyTorch version processing out of sort key (#68)

* move PyTorch version processing out of sort key

* refactor to not change candidate version
Philip Meier
2022-05-04 10:33:55 +02:00
committed by GitHub
parent 453b02bac0
commit 50dbc375b2
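
For orientation, here is a condensed, standalone sketch of the idea behind this change (hypothetical illustration, not the project's actual code: the `Candidate` stand-in, the `link_path` field, and the crude `backend_rank` ordering are simplifications of pip's `InstallationCandidate` and light-the-torch's `ComputationBackend`). The computation backend is extracted once per candidate by a small helper and reused both for filtering and as the leading component of the sort key, so the candidate's version object is never rewritten:

```python
# Hypothetical, simplified illustration of the refactor; not light-the-torch's real code.
import dataclasses
import re

from packaging.version import Version

_BACKEND_PATTERN = re.compile(r"/(?P<computation_backend>(cpu|cu\d+|rocm([\d.]+)))/")


@dataclasses.dataclass
class Candidate:
    """Stand-in for pip's InstallationCandidate with just the fields used here."""

    name: str
    version: Version
    link_path: str


def extract_local_specifier(candidate: Candidate) -> str:
    """Read the computation backend from the local version segment or the link path."""
    local = candidate.version.local
    if local is None:
        match = _BACKEND_PATTERN.search(candidate.link_path)
        local = match["computation_backend"] if match else "any"
    # Map the historical "any" marker onto "cpu" (see the comment in the diff below).
    return "cpu" if local == "any" else local


def sort_key(candidate: Candidate) -> tuple:
    # Prefer higher-ranked backends first, then newer base versions; the candidate's
    # version is only read, never modified.
    backend = extract_local_specifier(candidate)
    backend_rank = 1 if backend.startswith("cu") else 0  # crude stand-in for ComputationBackend ordering
    return (backend_rank, Version(candidate.version.base_version))


candidates = [
    Candidate("torch", Version("1.11.0+cpu"), "/whl/cpu/torch-1.11.0.whl"),
    Candidate("torch", Version("1.11.0+cu113"), "/whl/cu113/torch-1.11.0.whl"),
    Candidate("torch", Version("1.10.2"), "/whl/cu102/torch-1.10.2.whl"),
]

# Filter to the requested backends, then pick the best remaining candidate.
requested = {"cpu", "cu113"}
eligible = [c for c in candidates if extract_local_specifier(c) in requested]
best = max(eligible, key=sort_key)
print(best.version)  # 1.11.0+cu113
```
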

@@ -1,9 +1,8 @@
 import contextlib
 import dataclasses
 import enum
 import functools
+import itertools
 import optparse
 import re
 import sys
@@ -253,34 +252,56 @@ def patch_link_collection(computation_backends, channel):
 @contextlib.contextmanager
 def patch_candidate_selection(computation_backends):
-    allowed_locals = {None, *computation_backends}
     computation_backend_pattern = re.compile(
-        r"^/whl/(?P<computation_backend>(cpu|cu\d+|rocm([\d.]+)))/"
+        r"/(?P<computation_backend>(cpu|cu\d+|rocm([\d.]+)))/"
     )
 
+    def extract_local_specifier(candidate):
+        local = candidate.version.local
+
+        if local is None:
+            match = computation_backend_pattern.search(candidate.link.path)
+            local = match["computation_backend"] if match else "any"
+
+        # Early PyTorch distributions used the "any" local specifier to indicate a
+        # pure Python binary. This was changed to no local specifier later.
+        # Setting this to "cpu" is technically not correct as it will exclude this
+        # binary if a non-CPU backend is requested. Still, this is probably the
+        # right thing to do, since the user requested a specific backend and
+        # although this binary will work with it, it was not compiled against it.
+        if local == "any":
+            local = "cpu"
+
+        return local
+
     def preprocessing(input):
+        candidates = iter(input.candidates)
+        candidate = next(candidates)
+
+        if candidate.name not in PYTORCH_DISTRIBUTIONS:
+            # At this stage all candidates have the same name. Thus, if the first is
+            # not a PyTorch distribution, we don't need to check the rest and can
+            # return without changes.
+            return
+
         input.candidates = [
             candidate
-            for candidate in input.candidates
-            if candidate.name not in PYTORCH_DISTRIBUTIONS
-            or candidate.version.local in allowed_locals
+            for candidate in itertools.chain([candidate], candidates)
+            if extract_local_specifier(candidate) in computation_backends
         ]
 
-    sort_key = CandidateEvaluator._sort_key
+    vanilla_sort_key = CandidateEvaluator._sort_key
 
     def patched_sort_key(candidate_evaluator, candidate):
-        if candidate.name not in PYTORCH_DISTRIBUTIONS:
-            return sort_key(candidate_evaluator, candidate)
-
-        if candidate.version.local is not None:
-            computation_backend_str = candidate.version.local.replace("any", "cpu")
-        else:
-            match = computation_backend_pattern.match(candidate.link.path)
-            computation_backend_str = match["computation_backend"] if match else "cpu"
         # At this stage all candidates have the same name. Thus, we don't need to
         # mirror the exact key structure that the vanilla sort keys have.
         return (
-            cb.ComputationBackend.from_str(computation_backend_str),
-            candidate.version,
+            vanilla_sort_key(candidate_evaluator, candidate)
+            if candidate.name not in PYTORCH_DISTRIBUTIONS
+            else (
+                cb.ComputationBackend.from_str(extract_local_specifier(candidate)),
+                candidate.version.base_version,
+            )
         )
 
     with apply_fn_patch(