add support for pytorch-triton (#142)

This commit is contained in:
Philip Meier
2023-09-01 09:54:50 +02:00
committed by GitHub
parent 5ad5ff345f
commit ff84146751
2 changed files with 10 additions and 3 deletions

View File

@@ -45,6 +45,7 @@ PYTORCH_DISTRIBUTIONS = {
"torchserve",
"torchtext",
"torchvision",
"pytorch-triton",
}
THIRD_PARTY_PACKAGES = {
@@ -374,15 +375,22 @@ def patch_link_collection(computation_backends, channel, user_supplied_pinned_pa
@contextlib.contextmanager
def patch_candidate_selection(computation_backends):
computation_backend_pattern = re.compile(
computation_backend_link_pattern = re.compile(
r"/(?P<computation_backend>(cpu|cu\d+|rocm([\d.]+)))/"
)
def extract_local_specifier(candidate):
local = candidate.version.local
# Make sure that local actually is a computation backend identifier
if local is not None:
try:
cb.ComputationBackend.from_str(local)
except ValueError:
local = None
if local is None:
match = computation_backend_pattern.search(candidate.link.path)
match = computation_backend_link_pattern.search(candidate.link.comes_from)
local = match["computation_backend"] if match else "any"
# Early PyTorch distributions used the "any" local specifier to indicate a

View File

@@ -16,7 +16,6 @@ from light_the_torch._patch import (
EXCLUDED_PYTORCH_PACKAGES = {
"nestedtensor",
"pytorch_csprng",
"pytorch-triton",
"pytorch-triton-rocm",
"torch-cuda80",
"torch-nightly",