mirror of
https://github.com/pmeier/light-the-torch.git
synced 2024-09-08 23:29:28 +03:00
move PyTorch version processing out of sort key (#68)
* move PyTorch version processing out of sort key * refactor to not change candidate version
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
import contextlib
|
||||
import dataclasses
|
||||
|
||||
import enum
|
||||
import functools
|
||||
|
||||
import itertools
|
||||
import optparse
|
||||
import re
|
||||
import sys
|
||||
@@ -253,34 +252,56 @@ def patch_link_collection(computation_backends, channel):
|
||||
|
||||
@contextlib.contextmanager
|
||||
def patch_candidate_selection(computation_backends):
|
||||
allowed_locals = {None, *computation_backends}
|
||||
computation_backend_pattern = re.compile(
|
||||
r"^/whl/(?P<computation_backend>(cpu|cu\d+|rocm([\d.]+)))/"
|
||||
r"/(?P<computation_backend>(cpu|cu\d+|rocm([\d.]+)))/"
|
||||
)
|
||||
|
||||
def extract_local_specifier(candidate):
|
||||
local = candidate.version.local
|
||||
|
||||
if local is None:
|
||||
match = computation_backend_pattern.search(candidate.link.path)
|
||||
local = match["computation_backend"] if match else "any"
|
||||
|
||||
# Early PyTorch distributions used the "any" local specifier to indicate a
|
||||
# pure Python binary. This was changed to no local specifier later.
|
||||
# Setting this to "cpu" is technically not correct as it will exclude this
|
||||
# binary if a non-CPU backend is requested. Still, this is probably the
|
||||
# right thing to do, since the user requested a specific backend and
|
||||
# although this binary will work with it, it was not compiled against it.
|
||||
if local == "any":
|
||||
local = "cpu"
|
||||
|
||||
return local
|
||||
|
||||
def preprocessing(input):
|
||||
candidates = iter(input.candidates)
|
||||
candidate = next(candidates)
|
||||
|
||||
if candidate.name not in PYTORCH_DISTRIBUTIONS:
|
||||
# At this stage all candidates have the same name. Thus, if the first is
|
||||
# not a PyTorch distribution, we don't need to check the rest and can
|
||||
# return without changes.
|
||||
return
|
||||
|
||||
input.candidates = [
|
||||
candidate
|
||||
for candidate in input.candidates
|
||||
if candidate.name not in PYTORCH_DISTRIBUTIONS
|
||||
or candidate.version.local in allowed_locals
|
||||
for candidate in itertools.chain([candidate], candidates)
|
||||
if extract_local_specifier(candidate) in computation_backends
|
||||
]
|
||||
|
||||
sort_key = CandidateEvaluator._sort_key
|
||||
vanilla_sort_key = CandidateEvaluator._sort_key
|
||||
|
||||
def patched_sort_key(candidate_evaluator, candidate):
|
||||
if candidate.name not in PYTORCH_DISTRIBUTIONS:
|
||||
return sort_key(candidate_evaluator, candidate)
|
||||
|
||||
if candidate.version.local is not None:
|
||||
computation_backend_str = candidate.version.local.replace("any", "cpu")
|
||||
else:
|
||||
match = computation_backend_pattern.match(candidate.link.path)
|
||||
computation_backend_str = match["computation_backend"] if match else "cpu"
|
||||
|
||||
# At this stage all candidates have the same name. Thus, we don't need to
|
||||
# mirror the exact key structure that the vanilla sort keys have.
|
||||
return (
|
||||
cb.ComputationBackend.from_str(computation_backend_str),
|
||||
candidate.version,
|
||||
vanilla_sort_key(candidate_evaluator, candidate)
|
||||
if candidate.name not in PYTORCH_DISTRIBUTIONS
|
||||
else (
|
||||
cb.ComputationBackend.from_str(extract_local_specifier(candidate)),
|
||||
candidate.version.base_version,
|
||||
)
|
||||
)
|
||||
|
||||
with apply_fn_patch(
|
||||
|
||||
Reference in New Issue
Block a user