mirror of
https://github.com/pmeier/light-the-torch.git
synced 2024-09-08 23:29:28 +03:00
add support for pytorch-triton (#142)
This commit is contained in:
@@ -45,6 +45,7 @@ PYTORCH_DISTRIBUTIONS = {
|
||||
"torchserve",
|
||||
"torchtext",
|
||||
"torchvision",
|
||||
"pytorch-triton",
|
||||
}
|
||||
|
||||
THIRD_PARTY_PACKAGES = {
|
||||
@@ -374,15 +375,22 @@ def patch_link_collection(computation_backends, channel, user_supplied_pinned_pa
|
||||
|
||||
@contextlib.contextmanager
|
||||
def patch_candidate_selection(computation_backends):
|
||||
computation_backend_pattern = re.compile(
|
||||
computation_backend_link_pattern = re.compile(
|
||||
r"/(?P<computation_backend>(cpu|cu\d+|rocm([\d.]+)))/"
|
||||
)
|
||||
|
||||
def extract_local_specifier(candidate):
|
||||
local = candidate.version.local
|
||||
|
||||
# Make sure that local actually is a computation backend identifier
|
||||
if local is not None:
|
||||
try:
|
||||
cb.ComputationBackend.from_str(local)
|
||||
except ValueError:
|
||||
local = None
|
||||
|
||||
if local is None:
|
||||
match = computation_backend_pattern.search(candidate.link.path)
|
||||
match = computation_backend_link_pattern.search(candidate.link.comes_from)
|
||||
local = match["computation_backend"] if match else "any"
|
||||
|
||||
# Early PyTorch distributions used the "any" local specifier to indicate a
|
||||
|
||||
@@ -16,7 +16,6 @@ from light_the_torch._patch import (
|
||||
EXCLUDED_PYTORCH_PACKAGES = {
|
||||
"nestedtensor",
|
||||
"pytorch_csprng",
|
||||
"pytorch-triton",
|
||||
"pytorch-triton-rocm",
|
||||
"torch-cuda80",
|
||||
"torch-nightly",
|
||||
|
||||
Reference in New Issue
Block a user