mirror of
https://github.com/pmeier/light-the-torch.git
synced 2024-09-08 23:29:28 +03:00
rewrite project (#60)
* replace tox with doit (#42)
* replace tox with doit
* fix dev env setup
* fix dev env setup
* fix dev env setup
* apply fixes to all workflows
* try fix path
* try pip cache
* fix output location
* fix quotes
* fix quotes
* move path file extraction to script
* setup cache
* remove debug step
* add pre-commit caching
* rewrite ltt to wrap pip directly rather than use its internals
* add publish task and fix workflow
* fix publish
* update pre-commit hooks
* remove mypy
* cleanup
* add setup task
* fix some bugs
* disable output capturing on publish task
* fix publish task
* install wheel package in CI dev setup
* Rewrite smoke and computation backend tests (#44)
* fix publish task
* install wheel package in CI dev setup
* fix computation backend tests
* fix smoke tests
* delete other tests for now
* fix test workflow
* add CLI tests after rewrite (#46)
* rewrite README (#45)
* rewrite README
* update how does it work
* refactor the why section
* fix typo
Co-authored-by: James Butler <jamesobutler@users.noreply.github.com>
Co-authored-by: James Butler <jamesobutler@users.noreply.github.com>
* try appdirs to find pip cache path (#48)
* try appdirs to find pip cache path
* print
* fix app_author
* cleanup
* fix
* cleanup
* more cleanup
* refactor dodo (#47)
* refactor dodo
* fix test CI
* address review comments
Co-authored-by: Tony Fast <tony.fast@gmail.com>
Co-authored-by: Tony Fast <tony.fast@gmail.com>
* add default tasks (#49)
* fix module entrypoint (#50)
* fix candidate selection (#52)
* remove test task passthrough (#53)
* disable ROCm wheels (#54)
* fix ROCm deselection (#56)
* relax ROCm deselection even further (#57)
* add naive torch install test (#58)
* add naive torch install test
* trigger CI
* add cpuonly check
* don't fail fast
* add pytorch channel to the test matrix
* don't test LTS channel on 3.10
* add check without specifying the channel
* Revert "add check without specifying the channel"
This reverts commit 0842abf50f.
* use extra index rathe than find links for link patching (#59)
* update README
* fix Windows install
Co-authored-by: James Butler <jamesobutler@users.noreply.github.com>
Co-authored-by: Tony Fast <tony.fast@gmail.com>
This commit is contained in:
27
.flake8
27
.flake8
@@ -1,22 +1,19 @@
|
||||
[flake8]
|
||||
# See link below for available options
|
||||
# https://flake8.pycqa.org/en/latest/user/options.html#options-and-their-descriptions
|
||||
# Move this to pyproject.toml as soon as it is supported.
|
||||
# See https://gitlab.com/pycqa/flake8/issues/428
|
||||
# See https://flake8.pycqa.org/en/latest/user/options.html#options-and-their-descriptions
|
||||
# for available options
|
||||
|
||||
exclude =
|
||||
.git,
|
||||
.venv,
|
||||
.eggs,
|
||||
.mypy_cache,
|
||||
.pytest_cache,
|
||||
.tox,
|
||||
__pycache__,
|
||||
**/__pycache__,
|
||||
*.pyc,
|
||||
ignore = E203, E501, W503
|
||||
max-line-length = 88
|
||||
.venv,
|
||||
build,
|
||||
|
||||
# See https://www.flake8rules.com/ for a list of all builtin error codes
|
||||
ignore =
|
||||
E203, E501, W503
|
||||
per-file-ignores =
|
||||
__init__.py: F401, F403, F405
|
||||
conftest.py: F401, F403, F405
|
||||
__init__.py: F401
|
||||
conftest.py: F401
|
||||
|
||||
show_source = True
|
||||
statistics = True
|
||||
|
||||
46
.github/actions/setup-dev-env/action.yml
vendored
Normal file
46
.github/actions/setup-dev-env/action.yml
vendored
Normal file
@@ -0,0 +1,46 @@
|
||||
name: setup-dev-env
|
||||
description: "Setup development environment"
|
||||
|
||||
inputs:
|
||||
python-version:
|
||||
default: "3.7"
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
|
||||
steps:
|
||||
- name: Set up python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
|
||||
- name: Get pip cache path
|
||||
id: get-cache-path
|
||||
shell: bash
|
||||
run: |
|
||||
pip install appdirs
|
||||
CACHE_PATH=`python -c "import appdirs; print(appdirs.user_cache_dir('pip', appauthor=False))"`
|
||||
echo "::set-output name=path::$CACHE_PATH"
|
||||
|
||||
- name: Restore pip cache
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: ${{ steps.get-cache-path.outputs.path }}
|
||||
key: ${{ runner.os }}-pip-${{ hashFiles('requirements-dev.txt') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pip-
|
||||
|
||||
- name: Upgrade or install system packages
|
||||
shell: bash
|
||||
run: python -m pip install --upgrade pip wheel
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Install development requirements and project
|
||||
shell: bash
|
||||
run: |
|
||||
pip install doit
|
||||
doit install
|
||||
55
.github/workflows/install.yml
vendored
Normal file
55
.github/workflows/install.yml
vendored
Normal file
@@ -0,0 +1,55 @@
|
||||
name: install
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- releases/*
|
||||
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
torch_cpu:
|
||||
strategy:
|
||||
matrix:
|
||||
os:
|
||||
- ubuntu-latest
|
||||
- windows-latest
|
||||
- macos-latest
|
||||
python-version:
|
||||
- "3.7"
|
||||
- "3.8"
|
||||
- "3.9"
|
||||
- "3.10"
|
||||
pytorch-channel:
|
||||
- stable
|
||||
- test
|
||||
- nightly
|
||||
- lts
|
||||
exclude:
|
||||
- os: macos-latest
|
||||
pytorch-channel: lts
|
||||
- python-version: "3.10"
|
||||
pytorch-channel: lts
|
||||
fail-fast: false
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup development environment
|
||||
uses: ./.github/actions/setup-dev-env
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install torch
|
||||
run: ltt install --cpuonly --pytorch-channel=${{ matrix.pytorch-channel }} torch
|
||||
|
||||
- name: Check if CPU only
|
||||
run:
|
||||
python -c "import sys, torch; sys.exit(hasattr(torch._C,
|
||||
'_cuda_getDeviceCount'))"
|
||||
34
.github/workflows/lint.yml
vendored
34
.github/workflows/lint.yml
vendored
@@ -7,39 +7,27 @@ on:
|
||||
- releases/*
|
||||
|
||||
pull_request:
|
||||
paths:
|
||||
- "**.py"
|
||||
- "pyproject.toml"
|
||||
- ".pre-commit-config.yaml"
|
||||
- ".flake8"
|
||||
- "mypy.ini"
|
||||
- "tox.ini"
|
||||
- "requirements-dev.txt"
|
||||
- ".github/workflows/lint.yml"
|
||||
|
||||
jobs:
|
||||
check:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Set up python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: "3.6"
|
||||
|
||||
- name: Upgrade pip
|
||||
run: python -m pip install --upgrade pip
|
||||
|
||||
- name: Upgrade or install additional system packages
|
||||
run: pip install --upgrade setuptools virtualenv wheel
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Install dev requirements
|
||||
run: pip install -r requirements-dev.txt
|
||||
- name: Setup development environment
|
||||
uses: ./.github/actions/setup-dev-env
|
||||
|
||||
- name: Restore pre-commit cache
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: ~/.cache/pre-commit
|
||||
key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}
|
||||
restore-keys: |
|
||||
pre-commit-
|
||||
|
||||
- name: Run lint
|
||||
run: tox -e lint
|
||||
run: doit lint
|
||||
|
||||
27
.github/workflows/publish.yml
vendored
27
.github/workflows/publish.yml
vendored
@@ -2,38 +2,27 @@ name: publish
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [created]
|
||||
types:
|
||||
- published
|
||||
|
||||
jobs:
|
||||
pypi:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Set up python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: "3.6"
|
||||
|
||||
- name: Upgrade pip
|
||||
run: python -m pip install --upgrade pip
|
||||
|
||||
- name: Upgrade or install additional system packages
|
||||
run: pip install --upgrade setuptools virtualenv wheel
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Install pep517 and twine
|
||||
run: pip install pep517 twine
|
||||
- name: Setup development environment
|
||||
uses: ./.github/actions/setup-dev-env
|
||||
with:
|
||||
python-version: "3.7"
|
||||
|
||||
- name: Build source and binary
|
||||
run: python -m pep517.build --source --binary .
|
||||
|
||||
- name: Upload to PyPI
|
||||
- name: Publish to PyPI
|
||||
env:
|
||||
TWINE_REPOSITORY: pypi
|
||||
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
|
||||
run: twine upload dist/*
|
||||
run: doit publish
|
||||
|
||||
28
.github/workflows/publishable.yml
vendored
28
.github/workflows/publishable.yml
vendored
@@ -8,41 +8,31 @@ on:
|
||||
|
||||
pull_request:
|
||||
paths:
|
||||
- ".github/workflows/publishable.yml"
|
||||
- ".github/actions/setup-dev-env/**"
|
||||
- "light_the_torch/**"
|
||||
- "pyproject.toml"
|
||||
- "setup.cfg"
|
||||
- ".gitignore"
|
||||
- "CONTRIBUTING.rst"
|
||||
- "dodo.py"
|
||||
- "LICENSE"
|
||||
- "MANIFEST.in"
|
||||
- "pyproject.toml"
|
||||
- "README.rst"
|
||||
- "tox.ini"
|
||||
- "requirements-dev.txt"
|
||||
- ".github/workflows/publishable.yml"
|
||||
- "setup.cfg"
|
||||
|
||||
jobs:
|
||||
pypi:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Set up python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: "3.6"
|
||||
|
||||
- name: Upgrade pip
|
||||
run: python -m pip install --upgrade pip
|
||||
|
||||
- name: Upgrade or install additional system packages
|
||||
run: pip install --upgrade setuptools virtualenv wheel
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Install dev requirements
|
||||
run: pip install -r requirements-dev.txt
|
||||
- name: Setup development environment
|
||||
uses: ./.github/actions/setup-dev-env
|
||||
|
||||
- name: Run unit tests
|
||||
run: tox -e publishable
|
||||
- name: Check if publishable
|
||||
run: doit publishable
|
||||
|
||||
69
.github/workflows/tests.yml
vendored
69
.github/workflows/tests.yml
vendored
@@ -8,14 +8,16 @@ on:
|
||||
|
||||
pull_request:
|
||||
paths:
|
||||
- ".github/workflows/tests.yml"
|
||||
- ".github/actions/setup-dev-env/**"
|
||||
- "light_the_torch/**"
|
||||
- "tests/**"
|
||||
- "pytest.ini"
|
||||
- "tox.ini"
|
||||
- ".coveragerc"
|
||||
- "codecov.yml"
|
||||
- "dodo.py"
|
||||
- "pytest.ini"
|
||||
- "requirements-dev.txt"
|
||||
- ".github/workflows/tests.yml"
|
||||
- "setup.cfg"
|
||||
|
||||
schedule:
|
||||
- cron: "0 4 * * *"
|
||||
@@ -24,62 +26,37 @@ jobs:
|
||||
unit:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||
python: ['3.6', '3.7', '3.8']
|
||||
fail-fast: true
|
||||
os:
|
||||
- ubuntu-latest
|
||||
- windows-latest
|
||||
- macos-latest
|
||||
python-version:
|
||||
- "3.7"
|
||||
- "3.8"
|
||||
- "3.9"
|
||||
- "3.10"
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
env:
|
||||
OS: ${{ matrix.os }}
|
||||
PYTHON: ${{ matrix.python }}
|
||||
PYTHON_VERSION: ${{ matrix.python-version }}
|
||||
|
||||
steps:
|
||||
- name: Set up python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
|
||||
- name: Upgrade pip
|
||||
run: python -m pip install --upgrade pip
|
||||
|
||||
- name: Upgrade or install additional system packages
|
||||
run: pip install --upgrade setuptools virtualenv wheel
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Install dev requirements
|
||||
run: pip install -r requirements-dev.txt
|
||||
- name: Setup development environment
|
||||
uses: ./.github/actions/setup-dev-env
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Run unit tests
|
||||
run: tox -e py -- --skip-large-download
|
||||
run: doit test
|
||||
|
||||
- name: Upload coverage
|
||||
uses: codecov/codecov-action@v1.0.7
|
||||
uses: codecov/codecov-action@v2.1.0
|
||||
with:
|
||||
env_vars: OS,PYTHON
|
||||
|
||||
cli:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Set up python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: "3.6"
|
||||
|
||||
- name: Upgrade and install additional system packages
|
||||
run: pip install --upgrade pip setuptools virtualenv wheel
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Install project
|
||||
run: pip install .
|
||||
|
||||
- name: Test CLI
|
||||
run: |
|
||||
ltt --version
|
||||
ltt --help
|
||||
flags: unit
|
||||
env_vars: OS,PYTHON_VERSION
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,5 +1,7 @@
|
||||
light_the_torch/_version.py
|
||||
|
||||
.doit.db*
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
|
||||
@@ -1,20 +1,26 @@
|
||||
repos:
|
||||
- repo: https://github.com/timothycrosley/isort
|
||||
rev: "4.3.21"
|
||||
hooks:
|
||||
- id: isort
|
||||
args: [--settings-path=pyproject.toml, --filter-files]
|
||||
additional_dependencies: [toml]
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 19.10b0
|
||||
hooks:
|
||||
- id: black
|
||||
args: [--config=pyproject.toml]
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v3.1.0
|
||||
rev: v4.1.0
|
||||
hooks:
|
||||
- id: check-added-large-files
|
||||
- id: check-docstring-first
|
||||
- id: check-toml
|
||||
- id: check-yaml
|
||||
- id: trailing-whitespace
|
||||
- id: mixed-line-ending
|
||||
args:
|
||||
- --fix=lf
|
||||
- id: end-of-file-fixer
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-prettier
|
||||
rev: v2.5.1
|
||||
hooks:
|
||||
- id: prettier
|
||||
types_or:
|
||||
- markdown
|
||||
- toml
|
||||
- yaml
|
||||
|
||||
- repo: https://github.com/omnilib/ufmt
|
||||
rev: v1.3.2
|
||||
hooks:
|
||||
- id: ufmt
|
||||
additional_dependencies:
|
||||
- black == 22.1.0
|
||||
- usort == 1.0.2
|
||||
|
||||
2
.prettierrc.yaml
Normal file
2
.prettierrc.yaml
Normal file
@@ -0,0 +1,2 @@
|
||||
proseWrap: always
|
||||
printWidth: 88
|
||||
@@ -58,7 +58,7 @@ To run the full lint check locally run
|
||||
Tests
|
||||
-----
|
||||
|
||||
``light-the-torch`` uses `pytest <https://docs.pytest.org/en/stable/>`_ to run the test
|
||||
``light-the-torch`` uses `pytest <https://docs.pytest.org/en/stable/>`_ to run the test
|
||||
suite. You can run it locally with
|
||||
|
||||
.. code-block:: sh
|
||||
|
||||
@@ -6,8 +6,8 @@ exclude .flake8
|
||||
exclude .gitignore
|
||||
exclude .pre-commit-config.yaml
|
||||
exclude codecov.yml
|
||||
exclude dodo.py
|
||||
exclude MANIFEST.in
|
||||
exclude mypy.ini
|
||||
exclude pytest.ini
|
||||
exclude requirements-dev.txt
|
||||
exclude tox.ini
|
||||
|
||||
126
README.md
Normal file
126
README.md
Normal file
@@ -0,0 +1,126 @@
|
||||
# `light-the-torch`
|
||||
|
||||
[](https://opensource.org/licenses/BSD-3-Clause)
|
||||
[](https://www.repostatus.org/#wip)
|
||||
[](https://codecov.io/gh/pmeier/light-the-torch)
|
||||
|
||||
`light-the-torch` is a small utility that wraps `pip` to ease the installation process
|
||||
for PyTorch distributions and third-party packages that depend on them. It auto-detects
|
||||
compatible CUDA versions from the local setup and installs the correct PyTorch binaries
|
||||
without user interference.
|
||||
|
||||
- [Why do I need it?](#why-do-i-need-it)
|
||||
- [How do I install it?](#how-do-i-install-it)
|
||||
- [How do I use it?](#how-do-i-use-it)
|
||||
- [How does it work?](#how-does-it-work)
|
||||
|
||||
## Why do I need it?
|
||||
|
||||
PyTorch distributions are fully `pip install`'able, but PyPI, the default `pip` search
|
||||
index, has some limitations:
|
||||
|
||||
1. PyPI regularly only allows binaries up to a size of
|
||||
[approximately 60 MB](https://github.com/pypa/packaging-problems/issues/86). One can
|
||||
[request a file size limit increase](https://pypi.org/help/#file-size-limit) (and the
|
||||
PyTorch team probably does that for every release), but it is still not enough:
|
||||
although PyTorch has pre-built binaries for Windows with CUDA, they cannot be
|
||||
installed through PyPI due to their size.
|
||||
2. PyTorch uses local version specifiers to indicate for which computation backend the
|
||||
binary was compiled, for example `torch==1.11.0+cpu`. Unfortunately, local specifiers
|
||||
are not allowed on PyPI. Thus, only the binaries compiled with one CUDA version are
|
||||
uploaded without an indication of the CUDA version. If you do not have a CUDA capable
|
||||
GPU, downloading this is only a waste of bandwidth and disk capacity. If on the other
|
||||
hand your NVIDIA driver version simply doesn't support the CUDA version the binary
|
||||
was compiled with, you can't use any of the GPU features.
|
||||
|
||||
To overcome this, PyTorch also hosts _all_ binaries
|
||||
[themselves](https://download.pytorch.org/whl/torch_stable.html). To access them, you
|
||||
can still use `pip install` them, but some
|
||||
[additional options](https://pytorch.org/get-started/locally/) are needed:
|
||||
|
||||
```shell
|
||||
pip install torch --extra-index-url https://download.pytorch.org/whl/cu113
|
||||
```
|
||||
|
||||
While this is certainly an improvement, it still has a few downsides:
|
||||
|
||||
1. You need to know what computation backend, e.g. CUDA 11.3 (`cu113`), is supported on
|
||||
your local machine. This can be quite challenging for new users and at least tedious
|
||||
for more experienced ones.
|
||||
2. Besides the stable binaries, PyTorch also offers nightly, test, and long-time support
|
||||
(LTS) ones. To install them, you need a different `--extra-index-url` value for each.
|
||||
For the nightly and test channel you also need to supply the `--pre` option.
|
||||
3. If you want to install any package hosted on PyPI that depends on PyTorch, you always
|
||||
also specify all PyTorch distributions to install. Otherwise, the `--extra-index-url`
|
||||
flag is ignored and the PyTorch distributions hosted on PyPI will be installed.
|
||||
|
||||
If any of these points don't sound appealing to you, and you just want to have the same
|
||||
user experience as `pip install` for PyTorch distributions, `light-the-torch` was made
|
||||
for you.
|
||||
|
||||
## How do I install it?
|
||||
|
||||
Installing `light-the-torch` is as easy as
|
||||
|
||||
```shell
|
||||
pip install --pre light-the-torch
|
||||
```
|
||||
|
||||
Since it depends on `pip` and it might be upgraded during installation,
|
||||
[Windows users](https://pip.pypa.io/en/stable/installation/#upgrading-pip) should
|
||||
install it with
|
||||
|
||||
```shell
|
||||
py -m pip install --pre light-the-torch
|
||||
```
|
||||
|
||||
## How do I use it?
|
||||
|
||||
After `light-the-torch` is installed you can use its CLI interface `ltt` as drop-in
|
||||
replacement for `pip`:
|
||||
|
||||
```shell
|
||||
ltt install torch
|
||||
```
|
||||
|
||||
In fact, `ltt` is `pip` with a few added options:
|
||||
|
||||
- By default, `ltt` uses the local NVIDIA driver version to select the correct binary
|
||||
for you. You can pass the `--pytorch-computation-backend` option to manually specify
|
||||
the computation backend you want to use:
|
||||
|
||||
```shell
|
||||
ltt install --pytorch-computation-backend=cu102 torch
|
||||
```
|
||||
|
||||
- By default, `ltt` installs stable PyTorch binaries. To install binaries from nightly,
|
||||
test, or LTS channels pass the `--pytorch-channel` option:
|
||||
|
||||
```shell
|
||||
ltt install --pytorch-channel=nightly torch
|
||||
```
|
||||
|
||||
If `--pytorch-channel` is not passed, using `pip`'s builtin `--pre` option will
|
||||
install PyTorch test binaries.
|
||||
|
||||
Of course you are not limited to install only PyTorch distributions. Everything shown
|
||||
above also works if you install packages that depend on PyTorch:
|
||||
|
||||
```shell
|
||||
ltt install --pytorch-computation-backend=cpu --pytorch-channel=nightly pystiche
|
||||
```
|
||||
|
||||
## How does it work?
|
||||
|
||||
The authors of `pip` **do not condone** the use of `pip` internals as they might break
|
||||
without warning. As a results of this, `pip` has no capability for plugins to hook into
|
||||
specific tasks.
|
||||
|
||||
`light-the-torch` works by monkey-patching `pip` internals at runtime:
|
||||
|
||||
- While searching for a download link for a PyTorch distribution, `light-the-torch`
|
||||
replaces the default search index with an official PyTorch download link. This is
|
||||
equivalent to calling `pip install` with the `--extra-index-url` option only for
|
||||
PyTorch distributions.
|
||||
- While evaluating possible PyTorch installation candidates, `light-the-torch` culls
|
||||
binaries incompatible with the hardware.
|
||||
333
README.rst
333
README.rst
@@ -1,333 +0,0 @@
|
||||
light-the-torch
|
||||
===============
|
||||
|
||||
.. start-badges
|
||||
|
||||
.. list-table::
|
||||
:stub-columns: 1
|
||||
|
||||
* - package
|
||||
- |license| |status|
|
||||
* - code
|
||||
- |black| |mypy| |lint|
|
||||
* - tests
|
||||
- |tests| |coverage|
|
||||
|
||||
.. end-badges
|
||||
|
||||
``light-the-torch`` offers a small CLI (and
|
||||
`tox plugin <https://github.com/pmeier/tox-ltt>`_) based on ``pip`` to install PyTorch
|
||||
distributions from the stable releases. Similar to the platform and Python version, the
|
||||
computation backend is auto-detected from the available hardware preferring CUDA over
|
||||
CPU.
|
||||
|
||||
Motivation
|
||||
==========
|
||||
|
||||
With each release of a PyTorch distribution (``torch``, ``torchvision``,
|
||||
``torchaudio``, ``torchtext``) the wheels are published for combinations of different
|
||||
computation backends (CPU, CUDA), platforms, and Python versions. Unfortunately, a
|
||||
differentation based on the computation backend is not supported by
|
||||
`PEP 440 <https://www.python.org/dev/peps/pep-0440/>`_ . As a workaround the
|
||||
computation backend is added as a local specifier. For example
|
||||
|
||||
.. code-block:: sh
|
||||
|
||||
torch==1.5.1+cpu
|
||||
|
||||
Due to this restriction only the wheels of the latest CUDA release are uploaded to
|
||||
`PyPI <https://pypi.org/search/?q=torch>`_ and thus easily ``pip install`` able. For
|
||||
other CUDA versions or the installation without CUDA support, one has to resort to
|
||||
manual version specification:
|
||||
|
||||
.. code-block:: sh
|
||||
|
||||
pip install -f https://download.pytorch.org/whl/torch_stable.html torch==1.5.1+cu101
|
||||
|
||||
This is especially frustrating if one wants to install packages that depend on one or
|
||||
several PyTorch distributions: for each package the required PyTorch distributions have
|
||||
to be manually tracked down, resolved, and installed before the other requirements can
|
||||
be installed.
|
||||
|
||||
``light-the-torch`` was developed to overcome this.
|
||||
|
||||
Installation
|
||||
============
|
||||
|
||||
The latest **published** version can be installed with
|
||||
|
||||
.. code-block:: sh
|
||||
|
||||
pip install light-the-torch
|
||||
|
||||
|
||||
The latest, potentially unstable **development** version can be installed with
|
||||
|
||||
.. code-block::
|
||||
|
||||
pip install git+https://github.com/pmeier/light-the-torch
|
||||
|
||||
Usage
|
||||
=====
|
||||
|
||||
.. note::
|
||||
|
||||
The following examples were run on a linux machine with Python 3.6 and CUDA 10.1. The
|
||||
distributions hosted on PyPI were built with CUDA 10.2.
|
||||
|
||||
CLI
|
||||
---
|
||||
|
||||
The CLI of ``light-the-torch`` is invoked with its shorthand ``ltt``
|
||||
|
||||
.. code-block:: sh
|
||||
|
||||
$ ltt --help
|
||||
usage: ltt [-h] [-V] {install,extract,find} ...
|
||||
|
||||
optional arguments:
|
||||
-h, --help show this help message and exit
|
||||
-V, --version show light-the-torch version and path and exit
|
||||
|
||||
subcommands:
|
||||
{install,extract,find}
|
||||
|
||||
``ltt install``
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
.. code-block:: sh
|
||||
|
||||
$ ltt install --help
|
||||
usage: ltt install [-h] [--force-cpu] [--pytorch-only]
|
||||
[--install-cmd INSTALL_CMD] [--verbose]
|
||||
[args [args ...]]
|
||||
|
||||
Install PyTorch distributions from the stable releases. The computation
|
||||
backend is auto-detected from the available hardware preferring CUDA over CPU.
|
||||
|
||||
positional arguments:
|
||||
args arguments of 'pip install'. Optional arguments have to
|
||||
be seperated by '--'
|
||||
|
||||
optional arguments:
|
||||
-h, --help show this help message and exit
|
||||
--force-cpu disable computation backend auto-detection and use CPU
|
||||
instead
|
||||
--pytorch-only install only PyTorch distributions
|
||||
--install-cmd INSTALL_CMD
|
||||
installation command for the PyTorch distributions and
|
||||
additional packages. Defaults to 'python -m pip
|
||||
install {packages}'
|
||||
--verbose print more output to STDOUT. For fine control use -v /
|
||||
--verbose and -q / --quiet of the 'pip install'
|
||||
options
|
||||
|
||||
``ltt install`` is a drop-in replacement for ``pip install`` without worrying about the
|
||||
computation backend:
|
||||
|
||||
.. code-block:: sh
|
||||
|
||||
$ ltt install torch torchvision
|
||||
[...]
|
||||
Successfully installed future-0.18.2 numpy-1.19.0 pillow-7.2.0 torch-1.5.1+cu101 torchvision-0.6.1+cu101
|
||||
[...]
|
||||
|
||||
|
||||
``ltt install`` is also able to handle packages that depend on PyTorch distributions:
|
||||
|
||||
.. code-block:: sh
|
||||
|
||||
$ ltt install kornia
|
||||
[...]
|
||||
Successfully installed future-0.18.2 numpy-1.19.0 torch-1.5.0+cu101
|
||||
[...]
|
||||
Successfully installed kornia-0.3.1
|
||||
|
||||
``ltt extract``
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
.. code-block:: sh
|
||||
|
||||
$ ltt extract --help
|
||||
usage: ltt extract [-h] [--verbose] [args [args ...]]
|
||||
|
||||
Extract required PyTorch distributions
|
||||
|
||||
positional arguments:
|
||||
args arguments of 'pip install'. Optional arguments have to be
|
||||
seperated by '--'
|
||||
|
||||
optional arguments:
|
||||
-h, --help show this help message and exit
|
||||
--verbose print more output to STDOUT. For fine control use -v / --verbose
|
||||
and -q / --quiet of the 'pip install' options
|
||||
|
||||
|
||||
``ltt extract`` extracts the required PyTorch distributions out of packages:
|
||||
|
||||
.. code-block:: sh
|
||||
|
||||
$ ltt extract kornia
|
||||
torch==1.5.0
|
||||
|
||||
.. warning::
|
||||
|
||||
Internally, ``light-the-torch`` uses the ``pip`` resolver which, as of now,
|
||||
unfortunately allows conflicting dependencies:
|
||||
|
||||
.. code-block:: sh
|
||||
|
||||
$ ltt extract kornia "torch>1.5"
|
||||
torch>1.5
|
||||
|
||||
``ltt find``
|
||||
^^^^^^^^^^^^
|
||||
|
||||
.. code-block:: sh
|
||||
|
||||
$ ltt find --help
|
||||
usage: ltt find [-h] [--computation-backend COMPUTATION_BACKEND]
|
||||
[--platform PLATFORM] [--python-version PYTHON_VERSION]
|
||||
[--verbose]
|
||||
[args [args ...]]
|
||||
|
||||
Find wheel links for the required PyTorch distributions
|
||||
|
||||
positional arguments:
|
||||
args arguments of 'pip install'. Optional arguments have to
|
||||
be seperated by '--'
|
||||
|
||||
optional arguments:
|
||||
-h, --help show this help message and exit
|
||||
--computation-backend COMPUTATION_BACKEND
|
||||
Only use wheels compatible with COMPUTATION_BACKEND,
|
||||
for example 'cu102' or 'cpu'. Defaults to the
|
||||
computation backend of the running system, preferring
|
||||
CUDA over CPU.
|
||||
--platform PLATFORM Only use wheels compatible with <platform>. Defaults
|
||||
to the platform of the running system.
|
||||
--python-version PYTHON_VERSION
|
||||
The Python interpreter version to use for wheel and
|
||||
"Requires-Python" compatibility checks. Defaults to a
|
||||
version derived from the running interpreter. The
|
||||
version can be specified using up to three dot-
|
||||
separated integers (e.g. "3" for 3.0.0, "3.7" for
|
||||
3.7.0, or "3.7.3"). A major-minor version can also be
|
||||
given as a string without dots (e.g. "37" for 3.7.0).
|
||||
--verbose print more output to STDOUT. For fine control use -v /
|
||||
--verbose and -q / --quiet of the 'pip install'
|
||||
options
|
||||
|
||||
``ltt find`` finds the links to the wheels of the required PyTorch distributions:
|
||||
|
||||
.. code-block:: sh
|
||||
|
||||
$ ltt find torchaudio > requirements.txt
|
||||
$ cat requirements.txt
|
||||
https://download.pytorch.org/whl/cu101/torch-1.5.1%2Bcu101-cp36-cp36m-linux_x86_64.whl
|
||||
https://download.pytorch.org/whl/torchaudio-0.5.1-cp36-cp36m-linux_x86_64.whl
|
||||
|
||||
The ``--computation-backend``, ``--platform``, and ``python-version`` options can be
|
||||
used pin wheel properties instead of auto-detecting them:
|
||||
|
||||
.. code-block:: sh
|
||||
|
||||
$ ltt find \
|
||||
--computation-backend cu92 \
|
||||
--platform win_amd64 \
|
||||
--python-version 3.7 \
|
||||
torchtext
|
||||
https://download.pytorch.org/whl/cu92/torch-1.5.1%2Bcu92-cp37-cp37m-win_amd64.whl
|
||||
https://download.pytorch.org/whl/torchtext-0.6.0-py3-none-any.whl
|
||||
|
||||
Python
|
||||
------
|
||||
|
||||
``light-the-torch`` exposes two functions that can be used from Python:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import light_the_torch as ltt
|
||||
help(ltt.extract_dists)
|
||||
|
||||
.. code-block::
|
||||
|
||||
Help on function extract_dists in module light_the_torch._pip.extract:
|
||||
|
||||
extract_dists(pip_install_args:List[str], verbose:bool=False) -> List[str]
|
||||
Extract direct or indirect required PyTorch distributions.
|
||||
|
||||
Args:
|
||||
pip_install_args: Arguments passed to ``pip install`` that will be searched for
|
||||
required PyTorch distributions
|
||||
verbose: If ``True``, print additional information to STDOUT.
|
||||
|
||||
Returns:
|
||||
Resolved required PyTorch distributions.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import light_the_torch as ltt
|
||||
help(ltt.find_links)
|
||||
|
||||
.. code-block::
|
||||
|
||||
Help on function find_links in module light_the_torch._pip.find:
|
||||
|
||||
find_links(pip_install_args:List[str], computation_backend:Union[str, light_the_torch.computation_backend.ComputationBackend, NoneType]=None, platform:Union[str, NoneType]=None, python_version:Union[str, NoneType]=None, verbose:bool=False) -> List[str]
|
||||
Find wheel links for direct or indirect PyTorch distributions with given
|
||||
properties.
|
||||
|
||||
Args:
|
||||
pip_install_args: Arguments passed to ``pip install`` that will be searched for
|
||||
required PyTorch distributions
|
||||
computation_backend: Computation backend, for example ``"cpu"`` or ``"cu102"``.
|
||||
Defaults to the available hardware of the running system preferring CUDA
|
||||
over CPU.
|
||||
platform: Platform, for example ``"linux_x86_64"`` or ``"win_amd64"``. Defaults
|
||||
to the platform of the running system.
|
||||
python_version: Python version, for example ``"3"`` or ``"3.7"``. Defaults to
|
||||
the version of the running interpreter.
|
||||
verbose: If ``True``, print additional information to STDOUT.
|
||||
|
||||
Returns:
|
||||
Wheel links with given properties for all required PyTorch distributions.
|
||||
|
||||
.. note::
|
||||
|
||||
Optional arguments for ``pip install`` have to be passed after a ``--`` seperator.
|
||||
|
||||
.. |license|
|
||||
image:: https://img.shields.io/badge/License-BSD%203--Clause-blue.svg
|
||||
:target: https://opensource.org/licenses/BSD-3-Clause
|
||||
:alt: License
|
||||
|
||||
.. |status|
|
||||
image:: https://www.repostatus.org/badges/latest/wip.svg
|
||||
:alt: Project Status: WIP
|
||||
:target: https://www.repostatus.org/#wip
|
||||
|
||||
.. |black|
|
||||
image:: https://img.shields.io/badge/code%20style-black-000000.svg
|
||||
:target: https://github.com/psf/black
|
||||
:alt: black
|
||||
|
||||
.. |mypy|
|
||||
image:: http://www.mypy-lang.org/static/mypy_badge.svg
|
||||
:target: http://mypy-lang.org/
|
||||
:alt: mypy
|
||||
|
||||
.. |lint|
|
||||
image:: https://github.com/pmeier/light-the-torch/workflows/lint/badge.svg
|
||||
:target: https://github.com/pmeier/light-the-torch/actions?query=workflow%3Alint+branch%3Amain
|
||||
:alt: Lint status via GitHub Actions
|
||||
|
||||
.. |tests|
|
||||
image:: https://github.com/pmeier/light-the-torch/workflows/tests/badge.svg
|
||||
:target: https://github.com/pmeier/light-the-torch/actions?query=workflow%3Atests+branch%3Amain
|
||||
:alt: Test status via GitHub Actions
|
||||
|
||||
.. |coverage|
|
||||
image:: https://codecov.io/gh/pmeier/light-the-torch/branch/main/graph/badge.svg
|
||||
:target: https://codecov.io/gh/pmeier/light-the-torch
|
||||
:alt: Test coverage via codecov.io
|
||||
144
dodo.py
Normal file
144
dodo.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import os
|
||||
import pathlib
|
||||
import shlex
|
||||
|
||||
from doit.action import CmdAction
|
||||
|
||||
|
||||
HERE = pathlib.Path(__file__).parent
|
||||
PACKAGE_NAME = "light_the_torch"
|
||||
|
||||
CI = os.environ.get("CI") == "1"
|
||||
|
||||
DOIT_CONFIG = dict(
|
||||
verbosity=2,
|
||||
backend="json",
|
||||
default_tasks=[
|
||||
"lint",
|
||||
"test",
|
||||
"publishable",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def do(*cmd):
|
||||
if len(cmd) == 1:
|
||||
cmd = cmd[0]
|
||||
if isinstance(cmd, str):
|
||||
cmd = shlex.split(cmd)
|
||||
return CmdAction(cmd, shell=False, cwd=HERE)
|
||||
|
||||
|
||||
def _install_dev_requirements(pip="python -m pip"):
|
||||
return f"{pip} install -r requirements-dev.txt"
|
||||
|
||||
|
||||
def _install_project(pip="python -m pip"):
|
||||
return f"{pip} install -e ."
|
||||
|
||||
|
||||
def task_install():
|
||||
"""Installs all development requirements and light-the-torch in development mode"""
|
||||
yield dict(
|
||||
name="dev",
|
||||
file_dep=[HERE / "requirements-dev.txt"],
|
||||
actions=[do(_install_dev_requirements())],
|
||||
)
|
||||
yield dict(
|
||||
name="project",
|
||||
actions=[
|
||||
do(_install_project()),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def task_setup():
|
||||
"""Sets up a development environment for light-the-torch"""
|
||||
dev_env = HERE / ".venv"
|
||||
pip = dev_env / "bin" / "pip"
|
||||
return dict(
|
||||
actions=[
|
||||
do(f"virtualenv {dev_env} --prompt='(light-the-torch-dev) '"),
|
||||
do(_install_dev_requirements(pip)),
|
||||
do(_install_project(pip)),
|
||||
lambda: print(
|
||||
f"run `source {dev_env / 'bin' / 'activate'}` the virtual environment"
|
||||
),
|
||||
],
|
||||
clean=[do(f"rm -rf {dev_env}")],
|
||||
uptodate=[lambda: dev_env.exists()],
|
||||
)
|
||||
|
||||
|
||||
def task_format():
|
||||
"""Auto-formats all project files"""
|
||||
return dict(
|
||||
actions=[
|
||||
do("pre-commit run --all-files"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def task_lint():
|
||||
"""Lints all project files"""
|
||||
return dict(
|
||||
actions=[
|
||||
do("flake8 --config=.flake8"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def task_test():
|
||||
"""Runs the test suite"""
|
||||
return dict(
|
||||
actions=[do(f"pytest -c pytest.ini --cov-report={'xml' if CI else 'term'}")],
|
||||
)
|
||||
|
||||
|
||||
def task_build():
|
||||
"""Builds the source distribution and wheel of light-the-torch"""
|
||||
return dict(
|
||||
actions=[
|
||||
do("python -m build ."),
|
||||
],
|
||||
clean=[
|
||||
do(f"rm -rf build dist {PACKAGE_NAME}.egg-info"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def task_publishable():
|
||||
"""Checks if metadata is correct"""
|
||||
yield dict(
|
||||
name="twine",
|
||||
actions=[
|
||||
# We need the lambda here to lazily glob the files in dist/*, since they
|
||||
# are only created by the build task rather than when this task is
|
||||
# created.
|
||||
do(lambda: ["twine", "check", *list((HERE / "dist").glob("*"))]),
|
||||
],
|
||||
task_dep=["build"],
|
||||
)
|
||||
yield dict(
|
||||
name="check-wheel-contents",
|
||||
actions=[
|
||||
do("check-wheel-contents dist"),
|
||||
],
|
||||
task_dep=["build"],
|
||||
)
|
||||
|
||||
|
||||
def task_publish():
|
||||
"""Publishes light-the-torch to PyPI"""
|
||||
return dict(
|
||||
# We need the lambda here to lazily glob the files in dist/*, since they are
|
||||
# only created by the build task rather than when this task is created.
|
||||
actions=[
|
||||
do(lambda: ["twine", "upload", *list((HERE / "dist").glob("*"))]),
|
||||
],
|
||||
task_dep=[
|
||||
"lint",
|
||||
"test",
|
||||
"publishable",
|
||||
],
|
||||
)
|
||||
@@ -1,3 +1,4 @@
|
||||
from ._version import version as __version__ # type: ignore[import]
|
||||
|
||||
from ._pip import *
|
||||
try:
|
||||
from ._version import version as __version__
|
||||
except ImportError:
|
||||
__version__ = "UNKNOWN"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .cli import main
|
||||
from ._cli import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -2,22 +2,10 @@ import platform
|
||||
import re
|
||||
import subprocess
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional, Set
|
||||
from typing import Any, List, Optional, Set
|
||||
|
||||
from pip._vendor.packaging.version import InvalidVersion, Version
|
||||
|
||||
__all__ = [
|
||||
"ComputationBackend",
|
||||
"CPUBackend",
|
||||
"CUDABackend",
|
||||
"detect_compatible_computation_backends",
|
||||
]
|
||||
|
||||
|
||||
class ParseError(ValueError):
|
||||
def __init__(self, string: str) -> None:
|
||||
super().__init__(f"Unable to parse {string} into a computation backend")
|
||||
|
||||
|
||||
class ComputationBackend(ABC):
|
||||
@property
|
||||
@@ -45,8 +33,8 @@ class ComputationBackend(ABC):
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, string: str) -> "ComputationBackend":
|
||||
parse_error = ParseError(string)
|
||||
string = string.lower()
|
||||
parse_error = ValueError(f"Unable to parse {string} into a computation backend")
|
||||
string = string.strip().lower()
|
||||
if string == "cpu":
|
||||
return CPUBackend()
|
||||
elif string.startswith("cu"):
|
||||
@@ -97,21 +85,30 @@ class CUDABackend(ComputationBackend):
|
||||
|
||||
|
||||
def _detect_nvidia_driver_version() -> Optional[Version]:
|
||||
cmd = "nvidia-smi --query-gpu=driver_version --format=csv"
|
||||
try:
|
||||
output = (
|
||||
subprocess.check_output(cmd, shell=True, stderr=subprocess.DEVNULL)
|
||||
.decode("utf-8")
|
||||
.strip()
|
||||
result = subprocess.run(
|
||||
[
|
||||
"nvidia-smi",
|
||||
"--query-gpu=driver_version",
|
||||
"--format=csv",
|
||||
],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
return Version(output.splitlines()[-1])
|
||||
except (subprocess.CalledProcessError, InvalidVersion):
|
||||
return Version(result.stdout.splitlines()[-1])
|
||||
except (FileNotFoundError, subprocess.CalledProcessError, InvalidVersion):
|
||||
return None
|
||||
|
||||
|
||||
# Table 3 from https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
|
||||
_MINIMUM_DRIVER_VERSIONS = {
|
||||
"Linux": {
|
||||
Version("11.6"): Version("510.39.01"),
|
||||
Version("11.5"): Version("495.29.05"),
|
||||
Version("11.4"): Version("470.82.01"),
|
||||
Version("11.3"): Version("465.19.01"),
|
||||
Version("11.2"): Version("460.32.03"),
|
||||
Version("11.1"): Version("455.32"),
|
||||
Version("11.0"): Version("450.51.06"),
|
||||
Version("10.2"): Version("440.33"),
|
||||
@@ -121,9 +118,13 @@ _MINIMUM_DRIVER_VERSIONS = {
|
||||
Version("9.1"): Version("390.46"),
|
||||
Version("9.0"): Version("384.81"),
|
||||
Version("8.0"): Version("375.26"),
|
||||
Version("7.5"): Version("352.31"),
|
||||
},
|
||||
"Windows": {
|
||||
Version("11.6"): Version("511.23"),
|
||||
Version("11.5"): Version("496.13"),
|
||||
Version("11.4"): Version("472.50"),
|
||||
Version("11.3"): Version("465.89"),
|
||||
Version("11.2"): Version("461.33"),
|
||||
Version("11.1"): Version("456.81"),
|
||||
Version("11.0"): Version("451.82"),
|
||||
Version("10.2"): Version("441.22"),
|
||||
@@ -133,26 +134,25 @@ _MINIMUM_DRIVER_VERSIONS = {
|
||||
Version("9.1"): Version("391.29"),
|
||||
Version("9.0"): Version("385.54"),
|
||||
Version("8.0"): Version("376.51"),
|
||||
Version("7.5"): Version("353.66"),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _detect_compatible_cuda_backends() -> Set[CUDABackend]:
|
||||
def _detect_compatible_cuda_backends() -> List[CUDABackend]:
|
||||
driver_version = _detect_nvidia_driver_version()
|
||||
if not driver_version:
|
||||
return set()
|
||||
return []
|
||||
|
||||
minimum_driver_versions = _MINIMUM_DRIVER_VERSIONS.get(platform.system())
|
||||
if not minimum_driver_versions:
|
||||
return set()
|
||||
return []
|
||||
|
||||
return {
|
||||
return [
|
||||
CUDABackend(cuda_version.major, cuda_version.minor)
|
||||
for cuda_version, minimum_driver_version in minimum_driver_versions.items()
|
||||
if driver_version >= minimum_driver_version
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def detect_compatible_computation_backends() -> Set[ComputationBackend]:
|
||||
return {CPUBackend(), *_detect_compatible_cuda_backends()}
|
||||
return {*_detect_compatible_cuda_backends(), CPUBackend()}
|
||||
5
light_the_torch/_cli.py
Normal file
5
light_the_torch/_cli.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from pip._internal.cli.main import main as pip_main
|
||||
|
||||
from ._patch import patch
|
||||
|
||||
main = patch(pip_main)
|
||||
294
light_the_torch/_patch.py
Normal file
294
light_the_torch/_patch.py
Normal file
@@ -0,0 +1,294 @@
|
||||
import contextlib
|
||||
import dataclasses
|
||||
|
||||
import enum
|
||||
import functools
|
||||
|
||||
import optparse
|
||||
import re
|
||||
import sys
|
||||
import unittest.mock
|
||||
from typing import List, Set
|
||||
from unittest import mock
|
||||
|
||||
import pip._internal.cli.cmdoptions
|
||||
import pip._internal.index.collector
|
||||
import pip._internal.index.package_finder
|
||||
from pip._internal.index.package_finder import CandidateEvaluator
|
||||
from pip._internal.models.candidate import InstallationCandidate
|
||||
from pip._internal.models.search_scope import SearchScope
|
||||
|
||||
import light_the_torch as ltt
|
||||
|
||||
from . import _cb as cb
|
||||
|
||||
from ._utils import apply_fn_patch
|
||||
|
||||
|
||||
class Channel(enum.Enum):
|
||||
STABLE = enum.auto()
|
||||
TEST = enum.auto()
|
||||
NIGHTLY = enum.auto()
|
||||
LTS = enum.auto()
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, string):
|
||||
return cls[string.upper()]
|
||||
|
||||
|
||||
PYTORCH_DISTRIBUTIONS = ("torch", "torchvision", "torchaudio", "torchtext")
|
||||
|
||||
|
||||
def patch(pip_main):
|
||||
@functools.wraps(pip_main)
|
||||
def wrapper(argv=None):
|
||||
if argv is None:
|
||||
argv = sys.argv[1:]
|
||||
|
||||
with apply_patches(argv):
|
||||
return pip_main(argv)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# adapted from https://stackoverflow.com/a/9307174
|
||||
class PassThroughOptionParser(optparse.OptionParser):
|
||||
def __init__(self):
|
||||
super().__init__(add_help_option=False)
|
||||
|
||||
def _process_args(self, largs, rargs, values):
|
||||
while rargs:
|
||||
try:
|
||||
super()._process_args(largs, rargs, values)
|
||||
except (optparse.BadOptionError, optparse.AmbiguousOptionError) as error:
|
||||
largs.append(error.opt_str)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LttOptions:
|
||||
computation_backends: Set[cb.ComputationBackend] = dataclasses.field(
|
||||
default_factory=lambda: {cb.CPUBackend()}
|
||||
)
|
||||
channel: Channel = Channel.STABLE
|
||||
|
||||
@staticmethod
|
||||
def computation_backend_parser_options():
|
||||
return [
|
||||
optparse.Option(
|
||||
"--pytorch-computation-backend",
|
||||
# TODO: describe multiple inputs
|
||||
help=(
|
||||
"Computation backend for compiled PyTorch distributions, "
|
||||
"e.g. 'cu102', 'cu115', or 'cpu'. "
|
||||
"If not specified, the computation backend is detected from the "
|
||||
"available hardware, preferring CUDA over CPU."
|
||||
),
|
||||
),
|
||||
optparse.Option(
|
||||
"--cpuonly",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Shortcut for '--pytorch-computation-backend=cpu'. "
|
||||
"If '--computation-backend' is used simultaneously, "
|
||||
"it takes precedence over '--cpuonly'."
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def channel_parser_option() -> optparse.Option:
|
||||
return optparse.Option(
|
||||
"--pytorch-channel",
|
||||
# FIXME add help text
|
||||
help="",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse(argv):
|
||||
parser = PassThroughOptionParser()
|
||||
|
||||
for option in LttOptions.computation_backend_parser_options():
|
||||
parser.add_option(option)
|
||||
parser.add_option(LttOptions.channel_parser_option())
|
||||
parser.add_option("--pre", dest="pre", action="store_true")
|
||||
|
||||
opts, _ = parser.parse_args(argv)
|
||||
return opts
|
||||
|
||||
@classmethod
|
||||
def from_pip_argv(cls, argv: List[str]):
|
||||
if not argv or argv[0] != "install":
|
||||
return cls()
|
||||
|
||||
opts = cls._parse(argv)
|
||||
|
||||
if opts.pytorch_computation_backend is not None:
|
||||
cbs = {
|
||||
cb.ComputationBackend.from_str(string)
|
||||
for string in opts.pytorch_computation_backend.split(",")
|
||||
}
|
||||
elif opts.cpuonly:
|
||||
cbs = {cb.CPUBackend()}
|
||||
else:
|
||||
cbs = cb.detect_compatible_computation_backends()
|
||||
|
||||
if opts.pytorch_channel is not None:
|
||||
channel = Channel.from_str(opts.pytorch_channel)
|
||||
elif opts.pre:
|
||||
channel = Channel.TEST
|
||||
else:
|
||||
channel = Channel.STABLE
|
||||
|
||||
return cls(cbs, channel)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def apply_patches(argv):
|
||||
options = LttOptions.from_pip_argv(argv)
|
||||
|
||||
patches = [
|
||||
patch_cli_version(),
|
||||
patch_cli_options(),
|
||||
patch_link_collection(options.computation_backends, options.channel),
|
||||
patch_candidate_selection(options.computation_backends),
|
||||
]
|
||||
|
||||
with contextlib.ExitStack() as stack:
|
||||
for patch in patches:
|
||||
stack.enter_context(patch)
|
||||
|
||||
yield stack
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def patch_cli_version():
|
||||
with apply_fn_patch(
|
||||
"pip",
|
||||
"_internal",
|
||||
"cli",
|
||||
"main_parser",
|
||||
"get_pip_version",
|
||||
postprocessing=lambda input, output: f"ltt {ltt.__version__} from {ltt.__path__[0]}\n{output}",
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def patch_cli_options():
|
||||
def postprocessing(input, output):
|
||||
for option in LttOptions.computation_backend_parser_options():
|
||||
input.cmd_opts.add_option(option)
|
||||
|
||||
index_group = pip._internal.cli.cmdoptions.index_group
|
||||
|
||||
with apply_fn_patch(
|
||||
"pip",
|
||||
"_internal",
|
||||
"cli",
|
||||
"cmdoptions",
|
||||
"add_target_python_options",
|
||||
postprocessing=postprocessing,
|
||||
):
|
||||
with unittest.mock.patch.dict(index_group):
|
||||
options = index_group["options"].copy()
|
||||
options.append(LttOptions.channel_parser_option)
|
||||
index_group["options"] = options
|
||||
yield
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def patch_link_collection(computation_backends, channel):
|
||||
if channel == channel != Channel.LTS:
|
||||
find_links = []
|
||||
# TODO: this template is not valid for all backends
|
||||
channel_path = f"{channel.name.lower()}/" if channel != Channel.STABLE else ""
|
||||
index_urls = [
|
||||
f"https://download.pytorch.org/whl/{channel_path}{backend}"
|
||||
for backend in sorted(computation_backends)
|
||||
]
|
||||
else:
|
||||
# TODO: expand this when there are more LTS versions
|
||||
# TODO: switch this to index_urls when
|
||||
# https://github.com/pytorch/pytorch/pull/74753 is resolved
|
||||
find_links = ["https://download.pytorch.org/whl/lts/1.8/torch_lts.html"]
|
||||
index_urls = []
|
||||
|
||||
search_scope = SearchScope.create(find_links=find_links, index_urls=index_urls)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def context(input):
|
||||
if input.project_name not in PYTORCH_DISTRIBUTIONS:
|
||||
yield
|
||||
return
|
||||
|
||||
with mock.patch.object(input.self, "search_scope", search_scope):
|
||||
yield
|
||||
|
||||
with apply_fn_patch(
|
||||
"pip",
|
||||
"_internal",
|
||||
"index",
|
||||
"collector",
|
||||
"LinkCollector",
|
||||
"collect_sources",
|
||||
context=context,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def patch_candidate_selection(computation_backends):
|
||||
allowed_locals = {None, *computation_backends}
|
||||
computation_backend_pattern = re.compile(
|
||||
r"^/whl/(?P<computation_backend>(cpu|cu\d+))/"
|
||||
)
|
||||
|
||||
def preprocessing(input):
|
||||
input.candidates = [
|
||||
candidate
|
||||
for candidate in input.candidates
|
||||
if candidate.name not in PYTORCH_DISTRIBUTIONS
|
||||
or candidate.version.local is None
|
||||
or "rocm" not in candidate.version.local
|
||||
]
|
||||
return input
|
||||
|
||||
def postprocessing(
|
||||
input, output: List[InstallationCandidate]
|
||||
) -> List[InstallationCandidate]:
|
||||
return [
|
||||
candidate
|
||||
for candidate in output
|
||||
if candidate.name not in PYTORCH_DISTRIBUTIONS
|
||||
or candidate.version.local in allowed_locals
|
||||
]
|
||||
|
||||
foo = CandidateEvaluator._sort_key
|
||||
|
||||
def sort_key(candidate_evaluator, candidate):
|
||||
if candidate.name not in PYTORCH_DISTRIBUTIONS:
|
||||
return foo(candidate_evaluator, candidate)
|
||||
|
||||
if candidate.version.local is not None:
|
||||
computation_backend_str = candidate.version.local.replace("any", "cpu")
|
||||
else:
|
||||
match = computation_backend_pattern.match(candidate.link.path)
|
||||
computation_backend_str = match["computation_backend"] if match else "cpu"
|
||||
|
||||
return (
|
||||
cb.ComputationBackend.from_str(computation_backend_str),
|
||||
candidate.version,
|
||||
)
|
||||
|
||||
with apply_fn_patch(
|
||||
"pip",
|
||||
"_internal",
|
||||
"index",
|
||||
"package_finder",
|
||||
"CandidateEvaluator",
|
||||
"get_applicable_candidates",
|
||||
preprocessing=preprocessing,
|
||||
postprocessing=postprocessing,
|
||||
):
|
||||
with unittest.mock.patch.object(CandidateEvaluator, "_sort_key", new=sort_key):
|
||||
yield
|
||||
@@ -1,2 +0,0 @@
|
||||
from .extract import *
|
||||
from .find import *
|
||||
@@ -1,103 +0,0 @@
|
||||
import optparse
|
||||
from typing import Any, Iterable, Type, TypeVar, cast
|
||||
|
||||
from pip._internal.commands.install import InstallCommand
|
||||
from pip._internal.resolution.base import BaseResolver
|
||||
from pip._internal.resolution.legacy.resolver import Resolver
|
||||
from pip._internal.utils.logging import setup_logging
|
||||
from pip._internal.utils.temp_dir import global_tempdir_manager, tempdir_registry
|
||||
|
||||
__all__ = [
|
||||
"InternalLTTError",
|
||||
"PatchedInstallCommand",
|
||||
"make_pip_install_parser",
|
||||
"run",
|
||||
"new_from_similar",
|
||||
"PatchedResolverBase",
|
||||
]
|
||||
|
||||
|
||||
class InternalLTTError(RuntimeError):
|
||||
def __init__(self) -> None:
|
||||
msg = (
|
||||
"Unexpected internal ltt error. If you ever encounter this "
|
||||
"message during normal operation, please submit a bug report at "
|
||||
"https://github.com/pmeier/light-the-torch/issues"
|
||||
)
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class PatchedInstallCommand(InstallCommand):
|
||||
def __init__(
|
||||
self, name: str = "name", summary: str = "summary", **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(name, summary, **kwargs)
|
||||
|
||||
|
||||
def make_pip_install_parser() -> optparse.OptionParser:
|
||||
return cast(optparse.OptionParser, PatchedInstallCommand().parser)
|
||||
|
||||
|
||||
def get_verbosity(options: optparse.Values, verbose: bool) -> int:
|
||||
if not verbose:
|
||||
return -1
|
||||
|
||||
return cast(int, options.verbose) - cast(int, options.quiet)
|
||||
|
||||
|
||||
def run(
|
||||
cmd: InstallCommand, args: Iterable[str], options: optparse.Values, verbose: bool
|
||||
) -> int:
|
||||
with cmd.main_context():
|
||||
cmd.tempdir_registry = cmd.enter_context(tempdir_registry())
|
||||
cmd.enter_context(global_tempdir_manager())
|
||||
|
||||
setup_logging(
|
||||
verbosity=get_verbosity(options, verbose),
|
||||
no_color=options.no_color,
|
||||
user_log_file=options.log,
|
||||
)
|
||||
|
||||
return cast(int, cmd.run(options, list(args)))
|
||||
|
||||
|
||||
def get_public_or_private_attr(obj: Any, name: str) -> Any:
|
||||
try:
|
||||
return getattr(obj, name)
|
||||
except AttributeError:
|
||||
try:
|
||||
return getattr(obj, f"_{name}")
|
||||
except AttributeError:
|
||||
msg = f"'{type(obj)}' has no attribute '{name}' or '_{name}'"
|
||||
raise AttributeError(msg)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def new_from_similar(cls: Type[T], obj: Any, names: Iterable[str], **kwargs: Any) -> T:
|
||||
attrs = {name: get_public_or_private_attr(obj, name) for name in names}
|
||||
attrs.update(kwargs)
|
||||
return cls(**attrs) # type: ignore[call-arg]
|
||||
|
||||
|
||||
class PatchedResolverBase(Resolver):
|
||||
@classmethod
|
||||
def from_resolver(cls, resolver: BaseResolver) -> "PatchedResolverBase":
|
||||
return new_from_similar(
|
||||
cls,
|
||||
resolver,
|
||||
(
|
||||
"preparer",
|
||||
"finder",
|
||||
"wheel_cache",
|
||||
"upgrade_strategy",
|
||||
"force_reinstall",
|
||||
"ignore_dependencies",
|
||||
"ignore_installed",
|
||||
"ignore_requires_python",
|
||||
"use_user_site",
|
||||
"make_install_req",
|
||||
"py_version_info",
|
||||
),
|
||||
)
|
||||
@@ -1,105 +0,0 @@
|
||||
import contextlib
|
||||
import re
|
||||
from typing import Any, List, NoReturn, cast
|
||||
|
||||
from pip._internal.req.req_install import InstallRequirement
|
||||
from pip._internal.req.req_set import RequirementSet
|
||||
|
||||
from ..compatibility import find_compatible_torch_version
|
||||
from .common import InternalLTTError, PatchedInstallCommand, PatchedResolverBase, run
|
||||
|
||||
__all__ = ["extract_dists"]
|
||||
|
||||
|
||||
def extract_dists(pip_install_args: List[str], verbose: bool = False) -> List[str]:
|
||||
"""Extract direct or indirect required PyTorch distributions.
|
||||
|
||||
Args:
|
||||
pip_install_args: Arguments passed to ``pip install`` that will be searched for
|
||||
required PyTorch distributions
|
||||
verbose: If ``True``, print additional information to STDOUT.
|
||||
|
||||
Returns:
|
||||
Resolved required PyTorch distributions.
|
||||
"""
|
||||
cmd = StopAfterPytorchDistsFoundInstallCommand()
|
||||
options, args = cmd.parser.parse_args(pip_install_args)
|
||||
try:
|
||||
run(cmd, args, options, verbose)
|
||||
except PytorchDistsFound as resolution:
|
||||
return resolution.dists
|
||||
else:
|
||||
raise InternalLTTError
|
||||
|
||||
|
||||
class PytorchDistsFound(RuntimeError):
|
||||
def __init__(self, dists: List[str]) -> None:
|
||||
self.dists = dists
|
||||
|
||||
|
||||
class StopAfterPytorchDistsFoundResolver(PatchedResolverBase):
|
||||
PYTORCH_CORE = "torch"
|
||||
PYTORCH_SUBS = ("vision", "text", "audio")
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self._pytorch_dists = (
|
||||
self.PYTORCH_CORE,
|
||||
*[f"{self.PYTORCH_CORE}{sub}" for sub in self.PYTORCH_SUBS],
|
||||
)
|
||||
self._pytorch_core_pattern = re.compile(
|
||||
f"^{self.PYTORCH_CORE}(?!({'|'.join(self.PYTORCH_SUBS)}))"
|
||||
)
|
||||
self._required_pytorch_dists: List[str] = []
|
||||
|
||||
def _resolve_one(
|
||||
self, requirement_set: RequirementSet, req_to_install: InstallRequirement
|
||||
) -> List[InstallRequirement]:
|
||||
if req_to_install.name not in self._pytorch_dists:
|
||||
return cast(
|
||||
List[InstallRequirement],
|
||||
super()._resolve_one(requirement_set, req_to_install),
|
||||
)
|
||||
|
||||
self._required_pytorch_dists.append(str(req_to_install.req))
|
||||
return []
|
||||
|
||||
def resolve(
|
||||
self, root_reqs: List[InstallRequirement], check_supported_wheels: bool
|
||||
) -> NoReturn:
|
||||
super().resolve(root_reqs, check_supported_wheels)
|
||||
raise PytorchDistsFound(self.required_pytorch_dists)
|
||||
|
||||
@property
|
||||
def required_pytorch_dists(self) -> List[str]:
|
||||
dists = self._required_pytorch_dists
|
||||
if not dists:
|
||||
return []
|
||||
|
||||
# If the distribution was found in an extra requirement, pip appends this as
|
||||
# additional information. We remove that here.
|
||||
dists = [dist.split(";")[0] for dist in dists]
|
||||
|
||||
if not any(self._pytorch_core_pattern.match(dist) for dist in dists):
|
||||
torch = self.PYTORCH_CORE
|
||||
|
||||
with contextlib.suppress(RuntimeError):
|
||||
torch_versions = {
|
||||
find_compatible_torch_version(*dist.split("=="))
|
||||
for dist in dists
|
||||
if "==" in dist
|
||||
}
|
||||
if len(torch_versions) == 1:
|
||||
torch = f"{torch}=={torch_versions.pop()}"
|
||||
|
||||
dists.insert(0, torch)
|
||||
|
||||
return dists
|
||||
|
||||
|
||||
class StopAfterPytorchDistsFoundInstallCommand(PatchedInstallCommand):
|
||||
def make_resolver(
|
||||
self, *args: Any, **kwargs: Any
|
||||
) -> StopAfterPytorchDistsFoundResolver:
|
||||
resolver = super().make_resolver(*args, **kwargs)
|
||||
return StopAfterPytorchDistsFoundResolver.from_resolver(resolver)
|
||||
@@ -1,400 +0,0 @@
|
||||
import re
|
||||
from typing import (
|
||||
Any,
|
||||
Collection,
|
||||
Iterable,
|
||||
List,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Set,
|
||||
Text,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from pip._internal.index.collector import LinkCollector
|
||||
from pip._internal.index.package_finder import (
|
||||
CandidateEvaluator,
|
||||
CandidatePreferences,
|
||||
LinkEvaluator,
|
||||
PackageFinder,
|
||||
)
|
||||
from pip._internal.models.candidate import InstallationCandidate
|
||||
from pip._internal.models.link import Link
|
||||
from pip._internal.models.search_scope import SearchScope
|
||||
from pip._internal.req.req_install import InstallRequirement
|
||||
from pip._internal.req.req_set import RequirementSet
|
||||
from pip._vendor.packaging.version import Version
|
||||
|
||||
import light_the_torch.computation_backend as cb
|
||||
|
||||
from .common import (
|
||||
InternalLTTError,
|
||||
PatchedInstallCommand,
|
||||
PatchedResolverBase,
|
||||
new_from_similar,
|
||||
run,
|
||||
)
|
||||
from .extract import extract_dists
|
||||
|
||||
__all__ = ["find_links"]
|
||||
|
||||
|
||||
def find_links(
|
||||
pip_install_args: List[str],
|
||||
computation_backends: Optional[
|
||||
Union[cb.ComputationBackend, Collection[cb.ComputationBackend]]
|
||||
] = None,
|
||||
channel: str = "stable",
|
||||
platform: Optional[str] = None,
|
||||
python_version: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
) -> List[str]:
|
||||
"""Find wheel links for direct or indirect PyTorch distributions with given
|
||||
properties.
|
||||
|
||||
Args:
|
||||
pip_install_args: Arguments passed to ``pip install`` that will be searched for
|
||||
required PyTorch distributions
|
||||
computation_backends: Collection of supported computation backends, for example
|
||||
``"cpu"`` or ``"cu102"``. Defaults to the available hardware of the running
|
||||
system.
|
||||
channel: Channel of the PyTorch wheels. Can be one of ``"stable"`` (default),
|
||||
``"lts"``, ``"test"``, and ``"nightly"``.
|
||||
platform: Platform, for example ``"linux_x86_64"`` or ``"win_amd64"``. Defaults
|
||||
to the platform of the running system.
|
||||
python_version: Python version, for example ``"3"`` or ``"3.7"``. Defaults to
|
||||
the version of the running interpreter.
|
||||
verbose: If ``True``, print additional information to STDOUT.
|
||||
|
||||
Returns:
|
||||
Wheel links with given properties for all required PyTorch distributions.
|
||||
"""
|
||||
if computation_backends is None:
|
||||
computation_backends = cb.detect_compatible_computation_backends()
|
||||
elif isinstance(computation_backends, cb.ComputationBackend):
|
||||
computation_backends = {computation_backends}
|
||||
else:
|
||||
computation_backends = set(computation_backends)
|
||||
|
||||
if channel not in ("stable", "lts", "test", "nightly"):
|
||||
raise ValueError(
|
||||
f"channel can be one of 'stable', 'lts', 'test', or 'nightly', "
|
||||
f"but got {channel} instead."
|
||||
)
|
||||
|
||||
dists = extract_dists(pip_install_args)
|
||||
|
||||
cmd = StopAfterPytorchLinksFoundCommand(
|
||||
computation_backends=computation_backends, channel=channel
|
||||
)
|
||||
pip_install_args = adjust_pip_install_args(dists, platform, python_version)
|
||||
options, args = cmd.parser.parse_args(pip_install_args)
|
||||
try:
|
||||
run(cmd, args, options, verbose)
|
||||
except PytorchLinksFound as resolution:
|
||||
return resolution.links
|
||||
else:
|
||||
raise InternalLTTError
|
||||
|
||||
|
||||
def adjust_pip_install_args(
|
||||
pip_install_args: List[str], platform: Optional[str], python_version: Optional[str]
|
||||
) -> List[str]:
|
||||
if platform is None and python_version is None:
|
||||
return pip_install_args
|
||||
|
||||
if platform is not None:
|
||||
pip_install_args = maybe_add_option(
|
||||
pip_install_args, "--platform", value=platform
|
||||
)
|
||||
if python_version is not None:
|
||||
pip_install_args = maybe_add_option(
|
||||
pip_install_args, "--python-version", value=python_version
|
||||
)
|
||||
return maybe_set_required_options(pip_install_args)
|
||||
|
||||
|
||||
def maybe_add_option(
|
||||
args: List[str],
|
||||
option: str,
|
||||
value: Optional[str] = None,
|
||||
aliases: Iterable[str] = (),
|
||||
) -> List[str]:
|
||||
if any(arg in args for arg in (option, *aliases)):
|
||||
return args
|
||||
|
||||
additional_args = [option]
|
||||
if value is not None:
|
||||
additional_args.append(value)
|
||||
return additional_args + args
|
||||
|
||||
|
||||
def maybe_set_required_options(pip_install_args: List[str]) -> List[str]:
|
||||
pip_install_args = maybe_add_option(
|
||||
pip_install_args, "-t", value=".", aliases=("--target",)
|
||||
)
|
||||
pip_install_args = maybe_add_option(
|
||||
pip_install_args, "--only-binary", value=":all:"
|
||||
)
|
||||
return pip_install_args
|
||||
|
||||
|
||||
class PytorchLinksFound(RuntimeError):
|
||||
def __init__(self, links: List[str]) -> None:
|
||||
self.links = links
|
||||
|
||||
|
||||
class PytorchLinkEvaluator(LinkEvaluator):
|
||||
HAS_LOCAL_PATTERN = re.compile(r"[+](cpu|cu\d+)$")
|
||||
EXTRACT_LOCAL_PATTERN = re.compile(r"^/whl/(?P<local_specifier>(cpu|cu\d+))")
|
||||
|
||||
@classmethod
|
||||
def from_link_evaluator(
|
||||
cls, link_evaluator: LinkEvaluator
|
||||
) -> "PytorchLinkEvaluator":
|
||||
return new_from_similar(
|
||||
cls,
|
||||
link_evaluator,
|
||||
(
|
||||
"project_name",
|
||||
"canonical_name",
|
||||
"formats",
|
||||
"target_python",
|
||||
"allow_yanked",
|
||||
"ignore_requires_python",
|
||||
),
|
||||
)
|
||||
|
||||
def evaluate_link(self, link: Link) -> Tuple[bool, Optional[Text]]:
|
||||
output = cast(Tuple[bool, Optional[Text]], super().evaluate_link(link))
|
||||
is_candidate, result = output
|
||||
if not is_candidate:
|
||||
return output
|
||||
|
||||
result = cast(Text, result)
|
||||
has_local = self.HAS_LOCAL_PATTERN.search(result) is not None
|
||||
if has_local:
|
||||
return output
|
||||
|
||||
return True, f"{result}+{self.extract_computation_backend_from_link(link)}"
|
||||
|
||||
def extract_computation_backend_from_link(self, link: Link) -> Optional[str]:
|
||||
match = self.EXTRACT_LOCAL_PATTERN.match(link.path)
|
||||
if match is None:
|
||||
return "any"
|
||||
|
||||
return match.group("local_specifier")
|
||||
|
||||
|
||||
class PytorchCandidatePreferences(CandidatePreferences):
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
computation_backends: Set[cb.ComputationBackend],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.computation_backends = computation_backends
|
||||
|
||||
@classmethod
|
||||
def from_candidate_preferences(
|
||||
cls,
|
||||
candidate_preferences: CandidatePreferences,
|
||||
computation_backends: Set[cb.ComputationBackend],
|
||||
) -> "PytorchCandidatePreferences":
|
||||
return new_from_similar(
|
||||
cls,
|
||||
candidate_preferences,
|
||||
("prefer_binary", "allow_all_prereleases",),
|
||||
computation_backends=computation_backends,
|
||||
)
|
||||
|
||||
|
||||
class PytorchCandidateEvaluator(CandidateEvaluator):
|
||||
_MACOS_PLATFORM_PATTERN = re.compile(r"macosx_\d+_\d+_x86_64")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
computation_backends: Set[cb.ComputationBackend],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.computation_backends = computation_backends
|
||||
|
||||
@classmethod
|
||||
def from_candidate_evaluator(
|
||||
cls,
|
||||
candidate_evaluator: CandidateEvaluator,
|
||||
computation_backends: Set[cb.ComputationBackend],
|
||||
) -> "PytorchCandidateEvaluator":
|
||||
return new_from_similar(
|
||||
cls,
|
||||
candidate_evaluator,
|
||||
(
|
||||
"project_name",
|
||||
"supported_tags",
|
||||
"specifier",
|
||||
"prefer_binary",
|
||||
"allow_all_prereleases",
|
||||
"hashes",
|
||||
),
|
||||
computation_backends=computation_backends,
|
||||
)
|
||||
|
||||
def _sort_key(
|
||||
self, candidate: InstallationCandidate
|
||||
) -> Tuple[cb.ComputationBackend, Version]:
|
||||
return (
|
||||
cb.ComputationBackend.from_str(
|
||||
candidate.version.local.replace("any", "cpu")
|
||||
),
|
||||
candidate.version,
|
||||
)
|
||||
|
||||
def get_applicable_candidates(
|
||||
self, candidates: List[InstallationCandidate]
|
||||
) -> List[InstallationCandidate]:
|
||||
return [
|
||||
candidate
|
||||
for candidate in super().get_applicable_candidates(candidates)
|
||||
if candidate.version.local in self.computation_backends
|
||||
or candidate.version.local == "any"
|
||||
]
|
||||
|
||||
|
||||
class PytorchLinkCollector(LinkCollector):
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
computation_backends: Set[cb.ComputationBackend],
|
||||
channel: str = "stable",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
if channel == "stable":
|
||||
urls = ["https://download.pytorch.org/whl/torch_stable.html"]
|
||||
elif channel == "lts":
|
||||
urls = ["https://download.pytorch.org/whl/lts/1.8/torch_lts.html"]
|
||||
else:
|
||||
urls = [
|
||||
f"https://download.pytorch.org/whl/"
|
||||
f"{channel}/{backend}/torch_{channel}.html"
|
||||
for backend in sorted(computation_backends, key=str)
|
||||
]
|
||||
self.search_scope = SearchScope.create(find_links=urls, index_urls=[])
|
||||
|
||||
@classmethod
|
||||
def from_link_collector(
|
||||
cls,
|
||||
link_collector: LinkCollector,
|
||||
computation_backends: Set[cb.ComputationBackend],
|
||||
channel: str = "stable",
|
||||
) -> "PytorchLinkCollector":
|
||||
return new_from_similar(
|
||||
cls,
|
||||
link_collector,
|
||||
("session", "search_scope"),
|
||||
channel=channel,
|
||||
computation_backends=computation_backends,
|
||||
)
|
||||
|
||||
|
||||
class PytorchPackageFinder(PackageFinder):
|
||||
_candidate_prefs: PytorchCandidatePreferences
|
||||
_link_collector: PytorchLinkCollector
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
computation_backends: Set[cb.ComputationBackend],
|
||||
channel: str = "stable",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self._candidate_prefs = PytorchCandidatePreferences.from_candidate_preferences(
|
||||
self._candidate_prefs, computation_backends=computation_backends
|
||||
)
|
||||
self._link_collector = PytorchLinkCollector.from_link_collector(
|
||||
self._link_collector,
|
||||
channel=channel,
|
||||
computation_backends=computation_backends,
|
||||
)
|
||||
|
||||
def make_candidate_evaluator(
|
||||
self, *args: Any, **kwargs: Any,
|
||||
) -> PytorchCandidateEvaluator:
|
||||
candidate_evaluator = super().make_candidate_evaluator(*args, **kwargs)
|
||||
return PytorchCandidateEvaluator.from_candidate_evaluator(
|
||||
candidate_evaluator,
|
||||
computation_backends=self._candidate_prefs.computation_backends,
|
||||
)
|
||||
|
||||
def make_link_evaluator(self, *args: Any, **kwargs: Any) -> PytorchLinkEvaluator:
|
||||
link_evaluator = super().make_link_evaluator(*args, **kwargs)
|
||||
return PytorchLinkEvaluator.from_link_evaluator(link_evaluator)
|
||||
|
||||
@classmethod
|
||||
def from_package_finder(
|
||||
cls,
|
||||
package_finder: PackageFinder,
|
||||
computation_backends: Set[cb.ComputationBackend],
|
||||
channel: str = "stable",
|
||||
) -> "PytorchPackageFinder":
|
||||
return new_from_similar(
|
||||
cls,
|
||||
package_finder,
|
||||
(
|
||||
"link_collector",
|
||||
"target_python",
|
||||
"allow_yanked",
|
||||
"format_control",
|
||||
"candidate_prefs",
|
||||
"ignore_requires_python",
|
||||
),
|
||||
computation_backends=computation_backends,
|
||||
channel=channel,
|
||||
)
|
||||
|
||||
|
||||
class StopAfterPytorchLinksFoundResolver(PatchedResolverBase):
|
||||
def _resolve_one(
|
||||
self, requirement_set: RequirementSet, req_to_install: InstallRequirement
|
||||
) -> List[InstallRequirement]:
|
||||
self._populate_link(req_to_install)
|
||||
return []
|
||||
|
||||
def resolve(
|
||||
self, root_reqs: List[InstallRequirement], check_supported_wheels: bool
|
||||
) -> NoReturn:
|
||||
requirement_set = super().resolve(root_reqs, check_supported_wheels)
|
||||
links = [req.link.url for req in requirement_set.all_requirements]
|
||||
raise PytorchLinksFound(links)
|
||||
|
||||
|
||||
class StopAfterPytorchLinksFoundCommand(PatchedInstallCommand):
|
||||
def __init__(
|
||||
self,
|
||||
computation_backends: Set[cb.ComputationBackend],
|
||||
channel: str = "stable",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.computation_backends = computation_backends
|
||||
self.channel = channel
|
||||
|
||||
def _build_package_finder(self, *args: Any, **kwargs: Any) -> PytorchPackageFinder:
|
||||
package_finder = super()._build_package_finder(*args, **kwargs)
|
||||
return PytorchPackageFinder.from_package_finder(
|
||||
package_finder,
|
||||
computation_backends=self.computation_backends,
|
||||
channel=self.channel,
|
||||
)
|
||||
|
||||
def make_resolver(
|
||||
self, *args: Any, **kwargs: Any
|
||||
) -> StopAfterPytorchLinksFoundResolver:
|
||||
resolver = super().make_resolver(*args, **kwargs)
|
||||
return StopAfterPytorchLinksFoundResolver.from_resolver(resolver)
|
||||
113
light_the_torch/_utils.py
Normal file
113
light_the_torch/_utils.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import contextlib
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
import itertools
|
||||
|
||||
from unittest import mock
|
||||
|
||||
|
||||
class InternalError(RuntimeError):
|
||||
def __init__(self) -> None:
|
||||
# TODO: check against pip version
|
||||
# TODO: fix wording
|
||||
msg = (
|
||||
"Unexpected internal pytorch-pip-shim error. If you ever encounter this "
|
||||
"message during normal operation, please submit a bug report at "
|
||||
"https://github.com/pmeier/pytorch-pip-shim/issues"
|
||||
)
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class Input(dict):
|
||||
def __init__(self, fn, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.__fn__ = fn
|
||||
|
||||
def __getattr__(self, key):
|
||||
return self[key]
|
||||
|
||||
def __setattr__(self, key, value) -> None:
|
||||
self[key] = value
|
||||
|
||||
def __delattr__(self, key) -> None:
|
||||
del self[key]
|
||||
|
||||
@classmethod
|
||||
def from_call_args(cls, fn, *args, **kwargs):
|
||||
params = iter(inspect.signature(fn).parameters.values())
|
||||
for arg, param in zip(args, params):
|
||||
kwargs[param.name] = arg
|
||||
for param in params:
|
||||
if (
|
||||
param.name not in kwargs
|
||||
and param.default is not inspect.Parameter.empty
|
||||
):
|
||||
kwargs[param.name] = param.default
|
||||
return cls(fn, kwargs)
|
||||
|
||||
def to_call_args(self):
|
||||
params = iter(inspect.signature(self.__fn__).parameters.values())
|
||||
|
||||
args = []
|
||||
for param in params:
|
||||
if param.kind != inspect.Parameter.POSITIONAL_ONLY:
|
||||
break
|
||||
|
||||
args.append(self[param.name])
|
||||
else:
|
||||
return (), dict()
|
||||
args = tuple(args)
|
||||
|
||||
kwargs = dict()
|
||||
sentinel = object()
|
||||
for param in itertools.chain([param], params):
|
||||
kwarg = self.get(param.name, sentinel)
|
||||
if kwarg is not sentinel:
|
||||
kwargs[param.name] = kwarg
|
||||
|
||||
return args, kwargs
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def apply_fn_patch(
|
||||
*parts,
|
||||
preprocessing=lambda input: input,
|
||||
context=contextlib.nullcontext,
|
||||
postprocessing=lambda input, output: output,
|
||||
):
|
||||
target = ".".join(parts)
|
||||
fn = import_fn(target)
|
||||
|
||||
@functools.wraps(fn)
|
||||
def new(*args, **kwargs):
|
||||
input = Input.from_call_args(fn, *args, **kwargs)
|
||||
|
||||
input = preprocessing(input)
|
||||
with context(input):
|
||||
args, kwargs = input.to_call_args()
|
||||
output = fn(*args, **kwargs)
|
||||
return postprocessing(input, output)
|
||||
|
||||
with mock.patch(target, new=new):
|
||||
yield
|
||||
|
||||
|
||||
def import_fn(target: str):
|
||||
attrs = []
|
||||
name = target
|
||||
while name:
|
||||
try:
|
||||
module = importlib.import_module(name)
|
||||
break
|
||||
except ImportError:
|
||||
name, attr = name.rsplit(".", 1)
|
||||
attrs.append(attr)
|
||||
else:
|
||||
raise InternalError
|
||||
|
||||
obj = module
|
||||
for attr in attrs[::-1]:
|
||||
obj = getattr(obj, attr)
|
||||
|
||||
return obj
|
||||
@@ -1,176 +0,0 @@
|
||||
import argparse
|
||||
import sys
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
|
||||
from .._pip.common import make_pip_install_parser
|
||||
from .commands import make_command
|
||||
|
||||
__all__ = ["main"]
|
||||
|
||||
|
||||
def main(args: Optional[List[str]] = None) -> None:
|
||||
args = parse_args(args)
|
||||
cmd = make_command(args)
|
||||
|
||||
try:
|
||||
pip_install_args = args.args
|
||||
except AttributeError:
|
||||
pip_install_args = []
|
||||
cmd.run(pip_install_args)
|
||||
|
||||
|
||||
def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
|
||||
if args is None:
|
||||
args = sys.argv[1:]
|
||||
|
||||
parser = make_ltt_parser()
|
||||
return parser.parse_args(args)
|
||||
|
||||
|
||||
class LTTParser(argparse.ArgumentParser):
|
||||
def parse_known_args(
|
||||
self,
|
||||
args: Optional[Sequence[str]] = None,
|
||||
namespace: Optional[argparse.Namespace] = None,
|
||||
) -> Tuple[argparse.Namespace, List[str]]:
|
||||
args, argv = super().parse_known_args(args=args, namespace=namespace)
|
||||
if not argv:
|
||||
return args, argv
|
||||
|
||||
message = (
|
||||
f"Unrecognized arguments: {', '.join(argv)}. If they were meant as "
|
||||
"optional 'pip install' arguments, they have to be passed after a '--' "
|
||||
"seperator."
|
||||
)
|
||||
self.error(message)
|
||||
|
||||
@staticmethod
|
||||
def add_verbose(parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
help=(
|
||||
"print more output to STDOUT. For fine control use -v / --verbose and "
|
||||
"-q / --quiet of the 'pip install' options"
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_pip_install_args(parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
"args",
|
||||
nargs="*",
|
||||
help=(
|
||||
"arguments of 'pip install'. Optional arguments have to be seperated "
|
||||
"by '--'"
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_common_arguments(parser: argparse.ArgumentParser) -> None:
|
||||
LTTParser.add_verbose(parser)
|
||||
LTTParser.add_pip_install_args(parser)
|
||||
|
||||
|
||||
def make_ltt_parser() -> LTTParser:
|
||||
parser = LTTParser(prog="ltt")
|
||||
parser.add_argument(
|
||||
"-V",
|
||||
"--version",
|
||||
action="store_true",
|
||||
help="show light-the-torch version and path and exit",
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="subcommand", title="subcommands")
|
||||
add_ltt_install_parser(subparsers)
|
||||
add_ltt_extract_parser(subparsers)
|
||||
add_ltt_find_parser(subparsers)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
SubParsers = argparse._SubParsersAction
|
||||
|
||||
|
||||
def add_ltt_install_parser(subparsers: SubParsers) -> None:
|
||||
parser = subparsers.add_parser(
|
||||
"install",
|
||||
description=(
|
||||
"Install PyTorch distributions from the stable releases. The computation "
|
||||
"backend is auto-detected from the available hardware preferring CUDA "
|
||||
"over CPU."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-cpu",
|
||||
action="store_true",
|
||||
help="disable computation backend auto-detection and use CPU instead",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch-only",
|
||||
action="store_true",
|
||||
help="install only PyTorch distributions",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--channel",
|
||||
type=str,
|
||||
default="stable",
|
||||
help=(
|
||||
"Channel of the PyTorch wheels. "
|
||||
"Can be one of 'stable' (default), 'lts', 'test', or 'nightly'"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--install-cmd",
|
||||
type=str,
|
||||
default="python -m pip install {packages}",
|
||||
help=(
|
||||
"installation command for the PyTorch distributions and additional "
|
||||
"packages. Defaults to 'python -m pip install {packages}'"
|
||||
),
|
||||
)
|
||||
LTTParser.add_common_arguments(parser)
|
||||
|
||||
|
||||
def add_ltt_extract_parser(subparsers: SubParsers) -> None:
|
||||
parser = subparsers.add_parser(
|
||||
"extract", description="Extract required PyTorch distributions"
|
||||
)
|
||||
LTTParser.add_common_arguments(parser)
|
||||
|
||||
|
||||
def add_ltt_find_parser(subparsers: SubParsers) -> None:
|
||||
parser = subparsers.add_parser(
|
||||
"find", description="Find wheel links for the required PyTorch distributions"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--computation-backend",
|
||||
type=str,
|
||||
help=(
|
||||
"Only use wheels compatible with COMPUTATION_BACKEND, for example 'cu102' "
|
||||
"or 'cpu'. Defaults to the computation backend of the running system, "
|
||||
"preferring CUDA over CPU."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--channel",
|
||||
type=str,
|
||||
default="stable",
|
||||
help=(
|
||||
"Channel of the PyTorch wheels. "
|
||||
"Can be one of 'stable' (default), 'test', or 'nightly'"
|
||||
),
|
||||
)
|
||||
add_pip_install_arguments(parser, "platform", "python_version")
|
||||
LTTParser.add_common_arguments(parser)
|
||||
|
||||
|
||||
def add_pip_install_arguments(parser: argparse.ArgumentParser, *dests: str) -> None:
|
||||
pip_install_parser = make_pip_install_parser()
|
||||
option_group = pip_install_parser.option_groups[0]
|
||||
for dest in dests:
|
||||
options = [option for option in option_group.option_list if option.dest == dest]
|
||||
assert len(options) == 1
|
||||
option = options[0]
|
||||
|
||||
parser.add_argument(*option._short_opts, *option._long_opts, help=option.help)
|
||||
@@ -1,130 +0,0 @@
|
||||
import argparse
|
||||
import subprocess
|
||||
import sys
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from os import path
|
||||
from typing import Dict, List, NoReturn, Optional, Type
|
||||
|
||||
import light_the_torch as ltt
|
||||
|
||||
from .._pip.common import make_pip_install_parser
|
||||
from ..computation_backend import CPUBackend
|
||||
|
||||
__all__ = ["make_command"]
|
||||
|
||||
|
||||
class Command(ABC):
|
||||
@abstractmethod
|
||||
def __init__(self, args: argparse.Namespace) -> None:
|
||||
pass
|
||||
|
||||
def run(self, pip_install_args: List[str]) -> None:
|
||||
self._run(pip_install_args)
|
||||
self.exit()
|
||||
|
||||
@abstractmethod
|
||||
def _run(self, pip_install_args: List[str]) -> None:
|
||||
pass
|
||||
|
||||
def exit(self, code: Optional[int] = None, error: bool = False) -> NoReturn:
|
||||
if code is None:
|
||||
code = 1 if error else 0
|
||||
sys.exit(code)
|
||||
|
||||
|
||||
class GlobalCommand(Command):
|
||||
def __init__(self, args: argparse.Namespace) -> None:
|
||||
self.version = args.version
|
||||
|
||||
def _run(self, pip_install_args: List[str]) -> None:
|
||||
if self.version:
|
||||
root = path.abspath(path.join(path.dirname(__file__), ".."))
|
||||
print(f"{ltt.__name__}=={ltt.__version__} from {root}")
|
||||
|
||||
|
||||
class ExtractCommand(Command):
|
||||
def __init__(self, args: argparse.Namespace) -> None:
|
||||
self.verbose = args.verbose
|
||||
|
||||
def _run(self, pip_install_args: List[str]) -> None:
|
||||
dists = ltt.extract_dists(pip_install_args, verbose=self.verbose)
|
||||
print("\n".join(dists))
|
||||
|
||||
|
||||
class FindCommand(Command):
|
||||
def __init__(self, args: argparse.Namespace) -> None:
|
||||
# TODO split by comma
|
||||
self.computation_backends = args.computation_backend
|
||||
self.channel = args.channel
|
||||
self.platform = args.platform
|
||||
self.python_version = args.python_version
|
||||
self.verbose = args.verbose
|
||||
|
||||
def _run(self, pip_install_args: List[str]) -> None:
|
||||
links = ltt.find_links(
|
||||
pip_install_args,
|
||||
computation_backends=self.computation_backends,
|
||||
channel=self.channel,
|
||||
platform=self.platform,
|
||||
python_version=self.python_version,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
print("\n".join(links))
|
||||
|
||||
|
||||
class InstallCommand(Command):
|
||||
def __init__(self, args: argparse.Namespace) -> None:
|
||||
self.force_cpu = args.force_cpu
|
||||
self.pytorch_only = args.pytorch_only
|
||||
self.channel = args.channel
|
||||
|
||||
install_cmd = args.install_cmd
|
||||
if "{packages}" not in install_cmd:
|
||||
self.exit(error=True)
|
||||
self.install_cmd = install_cmd
|
||||
|
||||
self.verbose = args.verbose
|
||||
|
||||
def _run(self, pip_install_args: List[str]) -> None:
|
||||
links = ltt.find_links(
|
||||
pip_install_args,
|
||||
computation_backends={CPUBackend()} if self.force_cpu else None,
|
||||
channel=self.channel,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
if links:
|
||||
cmd = self.install_cmd.format(packages=" ".join(links))
|
||||
subprocess.check_call(cmd, shell=True)
|
||||
else:
|
||||
warnings.warn(
|
||||
f"Didn't find any PyTorch distribution in "
|
||||
f"'{' '.join(pip_install_args)}'",
|
||||
RuntimeWarning,
|
||||
)
|
||||
|
||||
if self.pytorch_only:
|
||||
self.exit()
|
||||
|
||||
cmd = self.install_cmd.format(packages=self.collect_packages(pip_install_args))
|
||||
subprocess.check_call(cmd, shell=True)
|
||||
|
||||
def collect_packages(self, pip_install_args: List[str]) -> str:
|
||||
parser = make_pip_install_parser()
|
||||
options, args = parser.parse_args(pip_install_args)
|
||||
editables = [f"-e {e}" for e in options.editables]
|
||||
requirements = [f"-r {r}" for r in options.requirements]
|
||||
return " ".join(editables + requirements + args)
|
||||
|
||||
|
||||
COMMAD_CLASSES: Dict[Optional[str], Type[Command]] = {
|
||||
None: GlobalCommand,
|
||||
"extract": ExtractCommand,
|
||||
"find": FindCommand,
|
||||
"install": InstallCommand,
|
||||
}
|
||||
|
||||
|
||||
def make_command(args: argparse.Namespace) -> Command:
|
||||
cls = COMMAD_CLASSES[args.subcommand]
|
||||
return cls(args)
|
||||
@@ -1,72 +0,0 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
__all__ = ["find_compatible_torch_version"]
|
||||
|
||||
|
||||
class Version:
|
||||
@classmethod
|
||||
def from_str(cls, version: str) -> "Version":
|
||||
parts = version.split(".")
|
||||
major = int(parts[0])
|
||||
minor = int(parts[1]) if len(parts) > 1 else None
|
||||
patch = int(parts[2]) if len(parts) > 2 else None
|
||||
return cls(major, minor, patch)
|
||||
|
||||
def __init__(self, major: int, minor: Optional[int], patch: Optional[int]) -> None:
|
||||
self.major = major
|
||||
self.minor = minor
|
||||
self.patch = patch
|
||||
self.parts = (major, minor, patch)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if not isinstance(other, Version):
|
||||
return False
|
||||
|
||||
return all(
|
||||
[
|
||||
self_part == other_part
|
||||
for self_part, other_part in zip(self.parts, other.parts)
|
||||
if self_part is not None and other_part is not None
|
||||
]
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.parts)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return ".".join([str(part) for part in self.parts if part is not None])
|
||||
|
||||
|
||||
COMPATIBILITY = {
|
||||
"torchvision": {
|
||||
Version(0, 9, 1): Version(1, 8, 1),
|
||||
Version(0, 9, 0): Version(1, 8, 0),
|
||||
Version(0, 8, 0): Version(1, 7, 0),
|
||||
Version(0, 7, 0): Version(1, 6, 0),
|
||||
Version(0, 6, 1): Version(1, 5, 1),
|
||||
Version(0, 6, 0): Version(1, 5, 0),
|
||||
Version(0, 5, 0): Version(1, 4, 0),
|
||||
Version(0, 4, 2): Version(1, 3, 1),
|
||||
Version(0, 4, 1): Version(1, 3, 0),
|
||||
Version(0, 4, 0): Version(1, 2, 0),
|
||||
Version(0, 3, 0): Version(1, 1, 0),
|
||||
Version(0, 2, 2): Version(1, 0, 1),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def find_compatible_torch_version(dist: str, version: str) -> str:
|
||||
version = Version.from_str(version)
|
||||
dist_compatibility = COMPATIBILITY[dist]
|
||||
candidates = [x for x in dist_compatibility.keys() if x == version]
|
||||
if not candidates:
|
||||
raise RuntimeError(
|
||||
f"No compatible torch version was found for {dist}=={version}"
|
||||
)
|
||||
if len(candidates) != 1:
|
||||
raise RuntimeError(
|
||||
f"Multiple compatible torch versions were found for {dist}=={version}:\n"
|
||||
f"{', '.join([str(candidate) for candidate in candidates])}\n"
|
||||
)
|
||||
|
||||
return str(dist_compatibility[candidates[0]])
|
||||
34
mypy.ini
34
mypy.ini
@@ -1,34 +0,0 @@
|
||||
[mypy]
|
||||
; https://mypy.readthedocs.io/en/stable/config_file.html
|
||||
|
||||
; import discovery
|
||||
files = light_the_torch
|
||||
|
||||
; untyped definitions and calls
|
||||
disallow_untyped_defs = True
|
||||
|
||||
; None and Optional handling
|
||||
no_implicit_optional = True
|
||||
|
||||
; warnings
|
||||
warn_redundant_casts = True
|
||||
warn_unused_ignores = True
|
||||
warn_return_any = True
|
||||
warn_unreachable = True
|
||||
|
||||
; miscellaneous strictness flags
|
||||
allow_redefinition = True
|
||||
|
||||
; configuring error messages
|
||||
show_error_context = True
|
||||
show_error_codes = True
|
||||
pretty = True
|
||||
|
||||
; miscellaneous
|
||||
warn_unused_configs = True
|
||||
|
||||
[mypy-light_the_torch]
|
||||
warn_unused_ignores = False
|
||||
|
||||
[mypy-pip.*]
|
||||
ignore_missing_imports = True
|
||||
22
pytest.ini
22
pytest.ini
@@ -1,10 +1,16 @@
|
||||
[pytest]
|
||||
;See link below for available options
|
||||
;https://docs.pytest.org/en/latest/reference.html#ini-options-ref
|
||||
;See https://docs.pytest.org/en/latest/reference.html#ini-options-ref for available
|
||||
; options
|
||||
|
||||
markers =
|
||||
large_download
|
||||
slow
|
||||
flaky
|
||||
testpaths = tests/
|
||||
addopts = -ra
|
||||
addopts =
|
||||
# show summary of all tests that did not pass
|
||||
-ra
|
||||
# Make tracebacks shorter
|
||||
--tb=short
|
||||
# enable all warnings
|
||||
-Wd
|
||||
# coverage
|
||||
--cov=light_the_torch
|
||||
--cov-config=.coveragerc
|
||||
xfail_strict = True
|
||||
testpaths = tests
|
||||
|
||||
@@ -1,2 +1,12 @@
|
||||
tox >= 3.2
|
||||
doit
|
||||
# format & lint
|
||||
pre-commit
|
||||
flake8 ==4.0.1
|
||||
# test
|
||||
pytest
|
||||
pytest-mock
|
||||
pytest-cov
|
||||
# publish
|
||||
build
|
||||
twine
|
||||
check-wheel-contents
|
||||
|
||||
21
setup.cfg
21
setup.cfg
@@ -1,13 +1,13 @@
|
||||
[metadata]
|
||||
name = light_the_torch
|
||||
platforms = any
|
||||
description = Install PyTorch distributions computation backend auto-detection
|
||||
long_description = file: README.rst
|
||||
long_description_content_type = text/x-rst
|
||||
description = Install PyTorch distributions with computation backend auto-detection
|
||||
long_description = file: README.md
|
||||
long_description_content_type = text/markdown; charset=UTF-8; variant=GFM
|
||||
keywords = pytorch, cuda, pip, install
|
||||
url = https://github.com/pmeier/light-the-torch
|
||||
author = Philip Meier
|
||||
author-email = github.pmeier@posteo.de
|
||||
author_email = github.pmeier@posteo.de
|
||||
license = BSD-3-Clause
|
||||
classifiers =
|
||||
Development Status :: 3 - Alpha
|
||||
@@ -15,12 +15,12 @@ classifiers =
|
||||
Environment :: GPU :: NVIDIA CUDA
|
||||
Intended Audience :: Developers
|
||||
License :: OSI Approved :: BSD License
|
||||
Programming Language :: Python :: 3.6
|
||||
Programming Language :: Python :: 3.7
|
||||
Programming Language :: Python :: 3.8
|
||||
Programming Language :: Python :: 3.9
|
||||
Programming Language :: Python :: 3.10
|
||||
Topic :: System :: Installation/Setup
|
||||
Topic :: Utilities
|
||||
Typing :: Typed
|
||||
project_urls =
|
||||
Source = https://github.com/pmeier/light-the-torch
|
||||
Tracker = https://github.com/pmeier/light-the-torch/issues
|
||||
@@ -28,15 +28,14 @@ project_urls =
|
||||
[options]
|
||||
packages = find:
|
||||
include_package_data = True
|
||||
python_requires = >=3.6
|
||||
python_requires = >=3.7
|
||||
install_requires =
|
||||
pip >=20.1.*, <20.3.*
|
||||
pip ==22.0.*, != 22.0, != 22.0.1, != 22.0.2
|
||||
|
||||
[options.packages.find]
|
||||
exclude =
|
||||
tests
|
||||
tests.*
|
||||
tests*
|
||||
|
||||
[options.entry_points]
|
||||
console_scripts =
|
||||
ltt=light_the_torch.cli:main
|
||||
ltt=light_the_torch._cli:main
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
import io
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_argv(mocker):
|
||||
def patch_argv_(*args):
|
||||
return mocker.patch.object(sys, "argv", ["ltt", *args])
|
||||
|
||||
return patch_argv_
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_extract_dists(mocker):
|
||||
def patch_extract_dists_(return_value=None):
|
||||
if return_value is None:
|
||||
return_value = []
|
||||
return mocker.patch(
|
||||
"light_the_torch.cli.commands.ltt.extract_dists", return_value=return_value,
|
||||
)
|
||||
|
||||
return patch_extract_dists_
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_find_links(mocker):
|
||||
def patch_find_links_(return_value=None):
|
||||
if return_value is None:
|
||||
return_value = []
|
||||
return mocker.patch(
|
||||
"light_the_torch.cli.commands.ltt.find_links", return_value=return_value,
|
||||
)
|
||||
|
||||
return patch_find_links_
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_stdout(mocker):
|
||||
def patch_stdout_():
|
||||
return mocker.patch.object(sys, "stdout", io.StringIO())
|
||||
|
||||
return patch_stdout_
|
||||
@@ -1,52 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from light_the_torch import cli
|
||||
|
||||
from .utils import exits
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_extract_argv(patch_argv):
|
||||
def patch_extract_argv_(*args):
|
||||
return patch_argv("extract", *args)
|
||||
|
||||
return patch_extract_argv_
|
||||
|
||||
|
||||
def test_ltt_extract(subtests, patch_extract_argv, patch_extract_dists, patch_stdout):
|
||||
pip_install_args = ["foo"]
|
||||
dists = ["bar", "baz"]
|
||||
|
||||
patch_extract_argv(*pip_install_args)
|
||||
extract_dists = patch_extract_dists(dists)
|
||||
stdout = patch_stdout()
|
||||
|
||||
with exits():
|
||||
cli.main()
|
||||
|
||||
with subtests.test("extract_dists"):
|
||||
args, _ = extract_dists.call_args
|
||||
assert args[0] == pip_install_args
|
||||
|
||||
with subtests.test("stdout"):
|
||||
output = stdout.getvalue().strip()
|
||||
assert output == "\n".join(dists)
|
||||
|
||||
|
||||
def test_ltt_extract_verbose(patch_extract_argv, patch_extract_dists):
|
||||
patch_extract_argv("--verbose")
|
||||
extract_dists = patch_extract_dists([])
|
||||
|
||||
with exits():
|
||||
cli.main()
|
||||
|
||||
_, kwargs = extract_dists.call_args
|
||||
assert "verbose" in kwargs
|
||||
assert kwargs["verbose"]
|
||||
|
||||
|
||||
def test_extract_unrecognized_argument(patch_extract_argv):
|
||||
patch_extract_argv("--unrecognized-argument")
|
||||
|
||||
with exits(error=True):
|
||||
cli.main()
|
||||
@@ -1,52 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from light_the_torch import cli
|
||||
|
||||
from .utils import exits
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_find_argv(patch_argv):
|
||||
def patch_find_argv_(*args):
|
||||
return patch_argv("find", *args)
|
||||
|
||||
return patch_find_argv_
|
||||
|
||||
|
||||
def test_ltt_find(subtests, patch_find_argv, patch_find_links, patch_stdout):
|
||||
pip_install_args = ["foo"]
|
||||
links = ["bar", "baz"]
|
||||
|
||||
patch_find_argv(*pip_install_args)
|
||||
find_links = patch_find_links(links)
|
||||
stdout = patch_stdout()
|
||||
|
||||
with exits():
|
||||
cli.main()
|
||||
|
||||
with subtests.test("find_links"):
|
||||
args, _ = find_links.call_args
|
||||
assert args[0] == pip_install_args
|
||||
|
||||
with subtests.test("stdout"):
|
||||
output = stdout.getvalue().strip()
|
||||
assert output == "\n".join(links)
|
||||
|
||||
|
||||
def test_ltt_find_verbose(patch_find_argv, patch_find_links):
|
||||
patch_find_argv("--verbose")
|
||||
find_links = patch_find_links([])
|
||||
|
||||
with exits():
|
||||
cli.main()
|
||||
|
||||
_, kwargs = find_links.call_args
|
||||
assert "verbose" in kwargs
|
||||
assert kwargs["verbose"]
|
||||
|
||||
|
||||
def test_find_unrecognized_argument(patch_find_argv):
|
||||
patch_find_argv("--unrecognized-argument")
|
||||
|
||||
with exits(error=True):
|
||||
cli.main()
|
||||
@@ -1,46 +0,0 @@
|
||||
import subprocess
|
||||
|
||||
import light_the_torch as ltt
|
||||
from light_the_torch import cli
|
||||
|
||||
from .utils import exits
|
||||
|
||||
|
||||
def test_ltt_main_smoke(subtests):
|
||||
for arg in ("-h", "-V"):
|
||||
cmd = f"python -m light_the_torch {arg}"
|
||||
with subtests.test(cmd=cmd):
|
||||
subprocess.check_call(cmd, shell=True)
|
||||
|
||||
|
||||
def test_ltt_help_smoke(subtests, patch_argv, patch_stdout):
|
||||
for arg in ("-h", "--help"):
|
||||
with subtests.test(arg=arg):
|
||||
patch_argv(arg)
|
||||
stdout = patch_stdout()
|
||||
|
||||
with exits():
|
||||
cli.main()
|
||||
|
||||
assert stdout.getvalue().strip()
|
||||
|
||||
|
||||
def test_ltt_version(subtests, patch_argv, patch_stdout):
|
||||
for arg in ("-V", "--version"):
|
||||
with subtests.test(arg=arg):
|
||||
patch_argv(arg)
|
||||
stdout = patch_stdout()
|
||||
|
||||
with exits():
|
||||
cli.main()
|
||||
|
||||
output = stdout.getvalue().strip()
|
||||
assert output.startswith(f"{ltt.__name__}=={ltt.__version__}")
|
||||
|
||||
|
||||
def test_ltt_unknown_subcommand(patch_argv):
|
||||
subcommand = "unkown"
|
||||
patch_argv(subcommand)
|
||||
|
||||
with exits(error=True):
|
||||
cli.main()
|
||||
@@ -1,215 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from light_the_torch import cli
|
||||
from light_the_torch.computation_backend import CPUBackend
|
||||
|
||||
from .utils import exits
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_install_argv(patch_argv):
|
||||
def patch_install_argv_(*args):
|
||||
return patch_argv("install", *args)
|
||||
|
||||
return patch_install_argv_
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_subprocess_call(mocker):
|
||||
def patch_subprocess_call_():
|
||||
return mocker.patch("light_the_torch.cli.commands.subprocess.check_call")
|
||||
|
||||
return patch_subprocess_call_
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_collect_packages(mocker):
|
||||
def patch_collect_packages_(return_value=None):
|
||||
if return_value is None:
|
||||
return_value = []
|
||||
return mocker.patch(
|
||||
"light_the_torch.cli.commands.InstallCommand.collect_packages",
|
||||
return_value=return_value,
|
||||
)
|
||||
|
||||
return patch_collect_packages_
|
||||
|
||||
|
||||
def test_ltt_install(
|
||||
subtests, patch_install_argv, patch_find_links, patch_subprocess_call
|
||||
):
|
||||
install_cmd = "python -m pip install {packages}"
|
||||
pip_install_args = ["foo", "bar"]
|
||||
links = ["https://foo.org"]
|
||||
|
||||
patch_install_argv(*pip_install_args)
|
||||
patch_find_links(links)
|
||||
subprocess_call = patch_subprocess_call()
|
||||
|
||||
with exits():
|
||||
cli.main()
|
||||
|
||||
assert subprocess_call.call_count == 2
|
||||
|
||||
with subtests.test("install PyTorch"):
|
||||
call_args = subprocess_call.call_args_list[0]
|
||||
args, _ = call_args
|
||||
assert args[0] == install_cmd.format(packages=" ".join(links))
|
||||
|
||||
with subtests.test("install remainder"):
|
||||
call_args = subprocess_call.call_args_list[1]
|
||||
args, _ = call_args
|
||||
assert args[0] == install_cmd.format(packages=" ".join(pip_install_args))
|
||||
|
||||
|
||||
def test_ltt_install_force_cpu(
|
||||
patch_install_argv, patch_find_links, patch_subprocess_call, patch_collect_packages,
|
||||
):
|
||||
patch_install_argv("--force-cpu")
|
||||
find_links = patch_find_links()
|
||||
patch_subprocess_call()
|
||||
patch_collect_packages()
|
||||
|
||||
with exits():
|
||||
cli.main()
|
||||
|
||||
_, kwargs = find_links.call_args
|
||||
assert "computation_backends" in kwargs
|
||||
assert set(kwargs["computation_backends"]) == {CPUBackend()}
|
||||
|
||||
|
||||
def test_ltt_install_pytorch_only(
|
||||
patch_install_argv, patch_find_links, patch_subprocess_call, patch_collect_packages,
|
||||
):
|
||||
patch_install_argv("--pytorch-only")
|
||||
patch_find_links()
|
||||
patch_subprocess_call()
|
||||
collect_packages = patch_collect_packages()
|
||||
|
||||
with exits():
|
||||
cli.main()
|
||||
|
||||
collect_packages.assert_not_called()
|
||||
|
||||
|
||||
def test_ltt_install_channel(
|
||||
patch_install_argv, patch_find_links, patch_subprocess_call, patch_collect_packages,
|
||||
):
|
||||
channel = "channel"
|
||||
|
||||
patch_install_argv(f"--channel={channel}")
|
||||
find_links = patch_find_links()
|
||||
patch_subprocess_call()
|
||||
patch_collect_packages()
|
||||
|
||||
with exits():
|
||||
cli.main()
|
||||
|
||||
_, kwargs = find_links.call_args
|
||||
assert "channel" in kwargs
|
||||
assert kwargs["channel"] == channel
|
||||
|
||||
|
||||
def test_ltt_install_install_cmd(
|
||||
patch_install_argv, patch_find_links, patch_subprocess_call,
|
||||
):
|
||||
install_cmd = "custom install {packages}"
|
||||
packages = ["foo", "bar"]
|
||||
|
||||
patch_install_argv("--pytorch-only", "--install-cmd", install_cmd)
|
||||
patch_find_links(packages)
|
||||
subprocess_call = patch_subprocess_call()
|
||||
|
||||
with exits():
|
||||
cli.main()
|
||||
|
||||
args, _ = subprocess_call.call_args
|
||||
assert args[0] == install_cmd.format(packages=" ".join(packages))
|
||||
|
||||
|
||||
def test_ltt_install_install_cmd_no_subs(patch_install_argv):
|
||||
patch_install_argv("--install-cmd", "no proper packages substitution")
|
||||
|
||||
with exits(error=True):
|
||||
cli.main()
|
||||
|
||||
|
||||
def test_ltt_install_editables(
|
||||
patch_install_argv, patch_find_links, patch_subprocess_call,
|
||||
):
|
||||
install_cmd = "custom install {packages}"
|
||||
editables = [".", "foo"]
|
||||
args = ["bar", "baz"]
|
||||
cmd = install_cmd.format(packages=" ".join([f"-e {e}" for e in editables] + args))
|
||||
|
||||
patch_install_argv(
|
||||
"--install-cmd",
|
||||
install_cmd,
|
||||
"--",
|
||||
"-e",
|
||||
editables[0],
|
||||
"--editable",
|
||||
editables[1],
|
||||
*args,
|
||||
)
|
||||
patch_find_links()
|
||||
subprocess_call = patch_subprocess_call()
|
||||
|
||||
with exits():
|
||||
cli.main()
|
||||
|
||||
args, _ = subprocess_call.call_args
|
||||
assert args[0] == cmd
|
||||
|
||||
|
||||
def test_ltt_install_requirements(
|
||||
patch_install_argv, patch_find_links, patch_subprocess_call,
|
||||
):
|
||||
install_cmd = "custom install {packages}"
|
||||
requirements = ["requirements.txt", "requirements-dev.txt"]
|
||||
args = ["foo", "bar"]
|
||||
cmd = install_cmd.format(
|
||||
packages=" ".join([f"-r {r}" for r in requirements] + args)
|
||||
)
|
||||
|
||||
patch_install_argv(
|
||||
"--install-cmd",
|
||||
install_cmd,
|
||||
"--",
|
||||
"-r",
|
||||
requirements[0],
|
||||
"--requirement",
|
||||
requirements[1],
|
||||
*args,
|
||||
)
|
||||
patch_find_links()
|
||||
subprocess_call = patch_subprocess_call()
|
||||
|
||||
with exits():
|
||||
cli.main()
|
||||
|
||||
args, _ = subprocess_call.call_args
|
||||
assert args[0] == cmd
|
||||
|
||||
|
||||
def test_ltt_install_verbose(
|
||||
patch_install_argv, patch_find_links, patch_subprocess_call, patch_collect_packages,
|
||||
):
|
||||
patch_install_argv("--verbose")
|
||||
find_links = patch_find_links()
|
||||
patch_subprocess_call()
|
||||
patch_collect_packages()
|
||||
|
||||
with exits():
|
||||
cli.main()
|
||||
|
||||
_, kwargs = find_links.call_args
|
||||
assert "verbose" in kwargs
|
||||
assert kwargs["verbose"]
|
||||
|
||||
|
||||
def test_install_unrecognized_argument(patch_install_argv):
|
||||
patch_install_argv("--unrecognized-argument")
|
||||
|
||||
with exits(error=True):
|
||||
cli.main()
|
||||
@@ -1,21 +0,0 @@
|
||||
import contextlib
|
||||
|
||||
import pytest
|
||||
|
||||
__all__ = ["exits"]
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def exits(code=None, error=False):
|
||||
with pytest.raises(SystemExit) as info:
|
||||
yield
|
||||
|
||||
ret = info.value.code
|
||||
|
||||
if code is not None:
|
||||
assert ret == code
|
||||
|
||||
if error:
|
||||
assert ret >= 1
|
||||
else:
|
||||
assert ret is None or ret == 0
|
||||
@@ -1,77 +0,0 @@
|
||||
import pytest
|
||||
|
||||
|
||||
class MarkConfig:
|
||||
def __init__(
|
||||
self,
|
||||
keyword,
|
||||
run_by_default,
|
||||
addoption=True,
|
||||
option=None,
|
||||
help=None,
|
||||
condition_for_skip=None,
|
||||
reason=None,
|
||||
):
|
||||
self.addoption = addoption
|
||||
|
||||
if option is None:
|
||||
option = (
|
||||
f"--{'skip' if run_by_default else 'run'}-{keyword.replace('_', '-')}"
|
||||
)
|
||||
self.option = option
|
||||
|
||||
if help is None:
|
||||
help = (
|
||||
f"{'Skip' if run_by_default else 'Run'} tests decorated with @{keyword}"
|
||||
)
|
||||
self.help = help
|
||||
|
||||
if condition_for_skip is None:
|
||||
|
||||
def condition_for_skip(config, item):
|
||||
has_keyword = keyword in item.keywords
|
||||
if run_by_default:
|
||||
return has_keyword and config.getoption(option)
|
||||
else:
|
||||
return has_keyword and not config.getoption(option)
|
||||
|
||||
self.condition_for_skip = condition_for_skip
|
||||
|
||||
if reason is None:
|
||||
reason = (
|
||||
f"Test is {keyword} and {option} was "
|
||||
f"{'' if run_by_default else 'not '}given."
|
||||
)
|
||||
self.marker = pytest.mark.skip(reason=reason)
|
||||
|
||||
|
||||
MARK_CONFIGS = (
|
||||
MarkConfig(
|
||||
keyword="large_download",
|
||||
run_by_default=True,
|
||||
reason=(
|
||||
"Test possibly includes a large download and --skip-large-download was "
|
||||
"given."
|
||||
),
|
||||
),
|
||||
MarkConfig(keyword="slow", run_by_default=True),
|
||||
MarkConfig(keyword="flaky", run_by_default=False),
|
||||
)
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
for mark_config in MARK_CONFIGS:
|
||||
if mark_config.addoption:
|
||||
parser.addoption(
|
||||
mark_config.option,
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=mark_config.help,
|
||||
)
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
for item in items:
|
||||
for mark_config in MARK_CONFIGS:
|
||||
if mark_config.condition_for_skip(config, item):
|
||||
item.add_marker(mark_config.marker)
|
||||
115
tests/test_cli.py
Normal file
115
tests/test_cli.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import contextlib
|
||||
import io
|
||||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import light_the_torch as ltt
|
||||
import pytest
|
||||
|
||||
from light_the_torch._cli import main
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cmd", ["ltt", "python -m light_the_torch"])
|
||||
def test_entry_point_smoke(cmd):
|
||||
subprocess.run(shlex.split(cmd), shell=False)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def exits(*, should_succeed=True, expected_code=None, check_err=None, check_out=None):
|
||||
def parse_checker(checker):
|
||||
if checker is None or callable(checker):
|
||||
return checker
|
||||
|
||||
if isinstance(checker, str):
|
||||
checker = (checker,)
|
||||
|
||||
def check_fn(text):
|
||||
for phrase in checker:
|
||||
assert phrase in text
|
||||
|
||||
return check_fn
|
||||
|
||||
check_err = parse_checker(check_err)
|
||||
check_out = parse_checker(check_out)
|
||||
|
||||
with pytest.raises(SystemExit) as info:
|
||||
with contextlib.redirect_stderr(io.StringIO()) as raw_err:
|
||||
with contextlib.redirect_stdout(io.StringIO()) as raw_out:
|
||||
yield
|
||||
|
||||
returned_code = info.value.code or 0
|
||||
succeeded = returned_code == 0
|
||||
err = raw_err.getvalue().strip()
|
||||
out = raw_out.getvalue().strip()
|
||||
|
||||
if expected_code is not None:
|
||||
if returned_code == expected_code:
|
||||
return
|
||||
|
||||
raise AssertionError(
|
||||
f"Returned and expected return code mismatch: "
|
||||
f"{returned_code} != {expected_code}."
|
||||
)
|
||||
|
||||
if should_succeed:
|
||||
if succeeded:
|
||||
if check_out:
|
||||
check_out(out)
|
||||
|
||||
return
|
||||
|
||||
raise AssertionError(
|
||||
f"Program should have succeeded, but returned code {returned_code} "
|
||||
f"and printed the following to STDERR: '{err}'."
|
||||
)
|
||||
else:
|
||||
if not succeeded:
|
||||
if check_err:
|
||||
check_err(err)
|
||||
|
||||
return
|
||||
|
||||
raise AssertionError("Program shouldn't have succeeded, but did.")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def set_argv(mocker):
|
||||
def patch(*options):
|
||||
return mocker.patch.object(sys, "argv", ["ltt", *options])
|
||||
|
||||
return patch
|
||||
|
||||
|
||||
@pytest.mark.parametrize("option", ["-h", "--help"])
|
||||
def test_help_smoke(set_argv, option):
|
||||
set_argv(option)
|
||||
|
||||
def check_out(out):
|
||||
assert out
|
||||
|
||||
with exits(check_out=check_out):
|
||||
main()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("option", ["-V", "--version"])
|
||||
def test_version(set_argv, option):
|
||||
set_argv(option)
|
||||
|
||||
with exits(check_out=f"ltt {ltt.__version__} from {ltt.__path__[0]}"):
|
||||
main()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"option",
|
||||
[
|
||||
"--pytorch-computation-backend",
|
||||
"--cpuonly",
|
||||
"--pytorch-channel",
|
||||
],
|
||||
)
|
||||
def test_ltt_options_smoke(set_argv, option):
|
||||
set_argv("install", "--help")
|
||||
|
||||
with exits(check_out=option):
|
||||
main()
|
||||
@@ -1,8 +1,40 @@
|
||||
import subprocess
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from light_the_torch import computation_backend as cb
|
||||
from light_the_torch import _cb as cb
|
||||
|
||||
try:
|
||||
subprocess.check_call(
|
||||
"nvidia-smi",
|
||||
shell=True,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
NVIDIA_DRIVER_AVAILABLE = True
|
||||
except subprocess.CalledProcessError:
|
||||
NVIDIA_DRIVER_AVAILABLE = False
|
||||
|
||||
|
||||
skip_if_nvidia_driver_unavailable = pytest.mark.skipif(
|
||||
not NVIDIA_DRIVER_AVAILABLE, reason="Requires nVidia driver."
|
||||
)
|
||||
|
||||
|
||||
class GenericComputationBackend(cb.ComputationBackend):
|
||||
@property
|
||||
def local_specifier(self):
|
||||
return "generic"
|
||||
|
||||
def __lt__(self, other):
|
||||
return NotImplemented
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def generic_backend():
|
||||
return GenericComputationBackend()
|
||||
|
||||
|
||||
class TestComputationBackend:
|
||||
@@ -42,7 +74,7 @@ class TestComputationBackend:
|
||||
|
||||
@pytest.mark.parametrize("string", (("unknown", "cudnn")))
|
||||
def test_from_str_unknown(self, string):
|
||||
with pytest.raises(cb.ParseError):
|
||||
with pytest.raises(ValueError, match=string):
|
||||
cb.ComputationBackend.from_str(string)
|
||||
|
||||
|
||||
@@ -70,26 +102,12 @@ class TestOrdering:
|
||||
assert cb.CUDABackend(2, 1) < cb.CUDABackend(10, 0)
|
||||
|
||||
|
||||
try:
|
||||
subprocess.check_call(
|
||||
"nvidia-smi", shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
|
||||
)
|
||||
NVIDIA_DRIVER_AVAILABLE = True
|
||||
except subprocess.CalledProcessError:
|
||||
NVIDIA_DRIVER_AVAILABLE = False
|
||||
|
||||
|
||||
skip_if_nvidia_driver_unavailable = pytest.mark.skipif(
|
||||
not NVIDIA_DRIVER_AVAILABLE, reason="Requires nVidia driver."
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_nvidia_driver_version(mocker):
|
||||
def factory(version):
|
||||
return mocker.patch(
|
||||
"light_the_torch.computation_backend.subprocess.check_output",
|
||||
return_value=f"driver_version\n{version}".encode("utf-8"),
|
||||
"light_the_torch._cb.subprocess.run",
|
||||
return_value=SimpleNamespace(stdout=f"driver_version\n{version}"),
|
||||
)
|
||||
|
||||
return factory
|
||||
@@ -129,7 +147,9 @@ def cuda_backends_params():
|
||||
pytest.param(
|
||||
system,
|
||||
str(driver_versions[idx]),
|
||||
set(cuda_backends[: idx + 1],),
|
||||
set(
|
||||
cuda_backends[: idx + 1],
|
||||
),
|
||||
id=f"{system.lower()}-normal",
|
||||
)
|
||||
)
|
||||
@@ -142,7 +162,7 @@ def cuda_backends_params():
|
||||
class TestDetectCompatibleComputationBackends:
|
||||
def test_no_nvidia_driver(self, mocker):
|
||||
mocker.patch(
|
||||
"light_the_torch.computation_backend.subprocess.check_output",
|
||||
"light_the_torch._cb.subprocess.run",
|
||||
side_effect=subprocess.CalledProcessError(1, ""),
|
||||
)
|
||||
|
||||
@@ -157,9 +177,7 @@ class TestDetectCompatibleComputationBackends:
|
||||
nvidia_driver_version,
|
||||
compatible_cuda_backends,
|
||||
):
|
||||
mocker.patch(
|
||||
"light_the_torch.computation_backend.platform.system", return_value=system
|
||||
)
|
||||
mocker.patch("light_the_torch._cb.platform.system", return_value=system)
|
||||
patch_nvidia_driver_version(nvidia_driver_version)
|
||||
|
||||
backends = cb.detect_compatible_computation_backends()
|
||||
121
tests/test_smoke.py
Normal file
121
tests/test_smoke.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import builtins
|
||||
import importlib
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
PACKAGE_NAME = "light_the_torch"
|
||||
PROJECT_ROOT = (pathlib.Path(__file__).parent / "..").resolve()
|
||||
PACKAGE_ROOT = PROJECT_ROOT / PACKAGE_NAME
|
||||
|
||||
|
||||
def collect_modules():
|
||||
def is_private(path):
|
||||
return pathlib.Path(path).name.startswith("_")
|
||||
|
||||
def path_to_module(path):
|
||||
return str(pathlib.Path(path).with_suffix("")).replace(os.sep, ".")
|
||||
|
||||
modules = []
|
||||
for root, dirs, files in os.walk(PACKAGE_ROOT):
|
||||
if is_private(root) or "__init__.py" not in files:
|
||||
del dirs[:]
|
||||
continue
|
||||
|
||||
path = pathlib.Path(root).relative_to(PROJECT_ROOT)
|
||||
modules.append(path_to_module(path))
|
||||
|
||||
for file in files:
|
||||
if is_private(file) or not file.endswith(".py"):
|
||||
continue
|
||||
|
||||
modules.append(path_to_module(path / file))
|
||||
|
||||
return modules
|
||||
|
||||
|
||||
@pytest.mark.parametrize("module", collect_modules())
|
||||
def test_importability(module):
|
||||
importlib.import_module(module)
|
||||
|
||||
|
||||
def import_package_under_test():
|
||||
try:
|
||||
return importlib.import_module(PACKAGE_NAME)
|
||||
except Exception as error:
|
||||
raise RuntimeError(
|
||||
f"The package '{PACKAGE_NAME}' could not be imported. "
|
||||
f"Check the results of tests/test_smoke.py::test_importability for details."
|
||||
) from error
|
||||
|
||||
|
||||
def test_version_installed():
|
||||
def is_canonical(version):
|
||||
# Copied from
|
||||
# https://www.python.org/dev/peps/pep-0440/#appendix-b-parsing-version-strings-with-regular-expressions
|
||||
return (
|
||||
re.match(
|
||||
r"^([1-9][0-9]*!)?(0|[1-9][0-9]*)(\.(0|[1-9][0-9]*))*((a|b|rc)(0|[1-9][0-9]*))?(\.post(0|[1-9][0-9]*))?(\.dev(0|[1-9][0-9]*))?$",
|
||||
version,
|
||||
)
|
||||
is not None
|
||||
)
|
||||
|
||||
def is_dev(version):
|
||||
match = re.search(r"\+g[\da-f]{7}([.]\d{14})?", version)
|
||||
if match is not None:
|
||||
return is_canonical(version[: match.span()[0]])
|
||||
else:
|
||||
return False
|
||||
|
||||
put = import_package_under_test()
|
||||
assert is_canonical(put.__version__) or is_dev(put.__version__)
|
||||
|
||||
|
||||
def patch_imports(
|
||||
mocker,
|
||||
*names,
|
||||
retain_condition=None,
|
||||
import_error_condition=None,
|
||||
):
|
||||
if retain_condition is None:
|
||||
|
||||
def retain_condition(name):
|
||||
return not any(name.startswith(name_) for name_ in names)
|
||||
|
||||
if import_error_condition is None:
|
||||
|
||||
def import_error_condition(name, globals, locals, fromlist, level):
|
||||
direct = name in names
|
||||
indirect = fromlist is not None and any(
|
||||
from_ in names for from_ in fromlist
|
||||
)
|
||||
return direct or indirect
|
||||
|
||||
__import__ = builtins.__import__
|
||||
|
||||
def patched_import(name, globals, locals, fromlist, level):
|
||||
if import_error_condition(name, globals, locals, fromlist, level):
|
||||
raise ImportError()
|
||||
|
||||
return __import__(name, globals, locals, fromlist, level)
|
||||
|
||||
mocker.patch.object(builtins, "__import__", new=patched_import)
|
||||
|
||||
values = {
|
||||
name: module for name, module in sys.modules.items() if retain_condition(name)
|
||||
}
|
||||
mocker.patch.dict(sys.modules, clear=True, values=values)
|
||||
|
||||
|
||||
def test_version_not_installed(mocker):
|
||||
def import_error_condition(name, globals, locals, fromlist, level):
|
||||
return name == "_version" and fromlist == ("version",)
|
||||
|
||||
patch_imports(mocker, PACKAGE_NAME, import_error_condition=import_error_condition)
|
||||
|
||||
put = import_package_under_test()
|
||||
assert put.__version__ == "UNKNOWN"
|
||||
@@ -1,79 +0,0 @@
|
||||
import itertools
|
||||
import optparse
|
||||
|
||||
import pytest
|
||||
|
||||
from light_the_torch._pip import common
|
||||
|
||||
|
||||
def test_get_verbosity(subtests):
|
||||
verboses = tuple(range(4))
|
||||
quiets = tuple(range(4))
|
||||
|
||||
for verbose, quiet in itertools.product(verboses, quiets):
|
||||
with subtests.test(verbose=verbose, quiet=quiet):
|
||||
options = optparse.Values({"verbose": verbose, "quiet": quiet})
|
||||
verbosity = verbose - quiet
|
||||
|
||||
assert common.get_verbosity(options, verbose=True) == verbosity
|
||||
assert common.get_verbosity(options, verbose=False) == -1
|
||||
|
||||
|
||||
def test_get_public_or_private_attr_public_and_private():
|
||||
class ObjWithPublicAndPrivateAttribute:
|
||||
attr = "public"
|
||||
_attr = "private"
|
||||
|
||||
obj = ObjWithPublicAndPrivateAttribute()
|
||||
assert common.get_public_or_private_attr(obj, "attr") == "public"
|
||||
|
||||
|
||||
def test_get_public_or_private_attr_public_only():
|
||||
class ObjWithPublicAttribute:
|
||||
attr = "public"
|
||||
|
||||
obj = ObjWithPublicAttribute()
|
||||
assert common.get_public_or_private_attr(obj, "attr") == "public"
|
||||
|
||||
|
||||
def test_get_public_or_private_attr_private_only():
|
||||
class ObjWithPrivateAttribute:
|
||||
_attr = "private"
|
||||
|
||||
obj = ObjWithPrivateAttribute()
|
||||
assert common.get_public_or_private_attr(obj, "attr") == "private"
|
||||
|
||||
|
||||
def test_get_public_or_private_attr_no_attribute():
|
||||
class ObjWithoutAttribute:
|
||||
pass
|
||||
|
||||
obj = ObjWithoutAttribute()
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
common.get_public_or_private_attr(obj, "attr")
|
||||
|
||||
|
||||
def test_new_from_similar():
|
||||
class Object:
|
||||
def __init__(self, foo, bar="bar"):
|
||||
self.foo = foo
|
||||
self._bar = bar
|
||||
|
||||
def __eq__(self, other):
|
||||
foo = self.foo == other.foo
|
||||
bar = self._bar == other._bar
|
||||
return foo and bar
|
||||
|
||||
class PatchedObj(Object):
|
||||
def __init__(self, foo, bar="patched_default_bar", baz=None):
|
||||
super().__init__(foo, bar=bar)
|
||||
self._baz = baz
|
||||
|
||||
obj = Object("foo")
|
||||
new_obj = common.new_from_similar(PatchedObj, obj, ("foo", "bar"), baz="baz")
|
||||
|
||||
assert new_obj is not obj
|
||||
assert isinstance(new_obj, PatchedObj)
|
||||
assert new_obj == obj
|
||||
assert common.get_public_or_private_attr(new_obj, "baz") == "baz"
|
||||
@@ -1,59 +0,0 @@
|
||||
import pytest
|
||||
|
||||
import light_the_torch as ltt
|
||||
from light_the_torch._pip import extract
|
||||
from light_the_torch._pip.common import InternalLTTError
|
||||
|
||||
|
||||
def test_StopAfterPytorchDistsFoundResolver_no_torch(mocker):
|
||||
mocker.patch(
|
||||
"light_the_torch._pip.extract.PatchedResolverBase.__init__", return_value=None
|
||||
)
|
||||
resolver = extract.StopAfterPytorchDistsFoundResolver()
|
||||
resolver._required_pytorch_dists = ["torchaudio", "torchtext", "torchvision"]
|
||||
assert "torch" in resolver.required_pytorch_dists
|
||||
|
||||
|
||||
def test_StopAfterPytorchDistsFoundResolver_torch_compatibility(mocker):
|
||||
mocker.patch(
|
||||
"light_the_torch._pip.extract.PatchedResolverBase.__init__", return_value=None
|
||||
)
|
||||
resolver = extract.StopAfterPytorchDistsFoundResolver()
|
||||
resolver._required_pytorch_dists = ["torchvision==0.7"]
|
||||
assert "torch==1.6.0" in resolver.required_pytorch_dists
|
||||
|
||||
|
||||
def test_extract_pytorch_internal_error(mocker):
|
||||
mocker.patch("light_the_torch._pip.extract.run")
|
||||
|
||||
with pytest.raises(InternalLTTError):
|
||||
ltt.extract_dists(["foo"])
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_extract_dists_ltt():
|
||||
assert ltt.extract_dists(["light-the-torch"]) == []
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_extract_dists_pystiche(subtests):
|
||||
pystiche = "git+https://github.com/pmeier/pystiche@v{}"
|
||||
reqs_and_dists = (
|
||||
(pystiche.format("0.4.0"), {"torch>=1.4.0", "torchvision>=0.5.0"}),
|
||||
(pystiche.format("0.5.0"), {"torch>=1.5.0", "torchvision>=0.6.0"}),
|
||||
)
|
||||
for req, dists in reqs_and_dists:
|
||||
with subtests.test(req=req):
|
||||
assert set(ltt.extract_dists([req])) == dists
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_extract_dists_kornia(subtests):
|
||||
kornia = "kornia=={}"
|
||||
reqs_and_dists = (
|
||||
(kornia.format("0.2.2"), {"torch<=1.4.0,>=1.0.0"}),
|
||||
(kornia.format("0.3.1"), {"torch==1.5.0"}),
|
||||
)
|
||||
for req, dists in reqs_and_dists:
|
||||
with subtests.test(req=req):
|
||||
assert set(ltt.extract_dists([req])) == dists
|
||||
@@ -1,195 +0,0 @@
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
from pip._internal.models.wheel import Wheel
|
||||
from pip._vendor.packaging.version import Version
|
||||
|
||||
import light_the_torch as ltt
|
||||
import light_the_torch.computation_backend as cb
|
||||
from light_the_torch._pip.common import InternalLTTError
|
||||
from light_the_torch._pip.find import maybe_add_option
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_extract_dists(mocker):
|
||||
def patch_extract_dists_(return_value=None):
|
||||
if return_value is None:
|
||||
return_value = []
|
||||
return mocker.patch(
|
||||
"light_the_torch._pip.find.extract_dists", return_value=return_value
|
||||
)
|
||||
|
||||
return patch_extract_dists_
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_run(mocker):
|
||||
def patch_run_():
|
||||
return mocker.patch("light_the_torch._pip.find.run")
|
||||
|
||||
return patch_run_
|
||||
|
||||
|
||||
CHANNELS = ("stable", "lts", "test", "nightly")
|
||||
PLATFORMS = ("linux_x86_64", "macosx_10_9_x86_64", "win_amd64")
|
||||
PLATFORM_MAP = dict(zip(PLATFORMS, ("Linux", "Darwin", "Windows")))
|
||||
|
||||
|
||||
SUPPORTED_PYTHON_VERSIONS = {
|
||||
Version("11.1"): tuple(f"3.{minor}" for minor in (6, 7, 8, 9)),
|
||||
Version("11.0"): tuple(f"3.{minor}" for minor in (6, 7, 8, 9)),
|
||||
Version("10.2"): tuple(f"3.{minor}" for minor in (6, 7, 8, 9)),
|
||||
Version("10.1"): tuple(f"3.{minor}" for minor in (6, 7, 8, 9)),
|
||||
Version("10.0"): tuple(f"3.{minor}" for minor in (6, 7, 8)),
|
||||
Version("9.2"): tuple(f"3.{minor}" for minor in (6, 7, 8, 9)),
|
||||
Version("9.1"): tuple(f"3.{minor}" for minor in (6,)),
|
||||
Version("9.0"): tuple(f"3.{minor}" for minor in (6, 7)),
|
||||
Version("8.0"): tuple(f"3.{minor}" for minor in (6, 7)),
|
||||
Version("7.5"): tuple(f"3.{minor}" for minor in (6,)),
|
||||
}
|
||||
PYTHON_VERSIONS = set(itertools.chain(*SUPPORTED_PYTHON_VERSIONS.values()))
|
||||
|
||||
|
||||
def test_maybe_add_option_already_set():
|
||||
args = ["--foo", "bar"]
|
||||
assert maybe_add_option(args, "--foo",) == args
|
||||
assert maybe_add_option(args, "-f", aliases=("--foo",)) == args
|
||||
|
||||
|
||||
def test_find_links_internal_error(patch_extract_dists, patch_run):
|
||||
patch_extract_dists()
|
||||
patch_run()
|
||||
|
||||
with pytest.raises(InternalLTTError):
|
||||
ltt.find_links([])
|
||||
|
||||
|
||||
def test_find_links_computation_backend_detect(
|
||||
mocker, patch_extract_dists, patch_run, generic_backend
|
||||
):
|
||||
computation_backends = {generic_backend}
|
||||
mocker.patch(
|
||||
"light_the_torch.computation_backend.detect_compatible_computation_backends",
|
||||
return_value=computation_backends,
|
||||
)
|
||||
|
||||
patch_extract_dists()
|
||||
run = patch_run()
|
||||
|
||||
with pytest.raises(InternalLTTError):
|
||||
ltt.find_links([], computation_backends=None)
|
||||
|
||||
args, _ = run.call_args
|
||||
cmd = args[0]
|
||||
assert cmd.computation_backends == computation_backends
|
||||
|
||||
|
||||
def test_find_links_unknown_channel():
|
||||
with pytest.raises(ValueError):
|
||||
ltt.find_links([], channel="channel")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("platform", PLATFORMS)
|
||||
def test_find_links_platform(patch_extract_dists, patch_run, platform):
|
||||
patch_extract_dists()
|
||||
run = patch_run()
|
||||
|
||||
with pytest.raises(InternalLTTError):
|
||||
ltt.find_links([], platform=platform)
|
||||
|
||||
args, _ = run.call_args
|
||||
options = args[2]
|
||||
assert options.platform == platform
|
||||
|
||||
|
||||
@pytest.mark.parametrize("python_version", PYTHON_VERSIONS)
|
||||
def test_find_links_python_version(patch_extract_dists, patch_run, python_version):
|
||||
patch_extract_dists()
|
||||
run = patch_run()
|
||||
|
||||
python_version_tuple = tuple(int(v) for v in python_version.split("."))
|
||||
|
||||
with pytest.raises(InternalLTTError):
|
||||
ltt.find_links([], python_version=python_version)
|
||||
|
||||
args, _ = run.call_args
|
||||
options = args[2]
|
||||
assert options.python_version == python_version_tuple
|
||||
|
||||
|
||||
def wheel_properties():
|
||||
params = []
|
||||
for platform in PLATFORMS:
|
||||
params.extend(
|
||||
[
|
||||
(platform, cb.CPUBackend(), python_version)
|
||||
for python_version in PYTHON_VERSIONS
|
||||
]
|
||||
)
|
||||
|
||||
system = PLATFORM_MAP[platform]
|
||||
cuda_versions = cb._MINIMUM_DRIVER_VERSIONS.get(system, {}).keys()
|
||||
if not cuda_versions:
|
||||
continue
|
||||
|
||||
params.extend(
|
||||
[
|
||||
(
|
||||
platform,
|
||||
cb.CUDABackend(cuda_version.major, cuda_version.minor),
|
||||
python_version,
|
||||
)
|
||||
for cuda_version in cuda_versions
|
||||
for python_version in SUPPORTED_PYTHON_VERSIONS[cuda_version]
|
||||
if not (
|
||||
platform == "win_amd64"
|
||||
and (
|
||||
(cuda_version == Version("7.5") and python_version == "3.6")
|
||||
or (cuda_version == Version("9.2") and python_version == "3.9")
|
||||
or (cuda_version == Version("10.0") and python_version == "3.8")
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
return pytest.mark.parametrize(
|
||||
("platform", "computation_backend", "python_version"), params, ids=str,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize(
|
||||
"pytorch_dist", ["torch", "torchaudio", "torchtext", "torchvision"]
|
||||
)
|
||||
@wheel_properties()
|
||||
def test_find_links_stable_smoke(
|
||||
pytorch_dist, platform, computation_backend, python_version
|
||||
):
|
||||
assert ltt.find_links(
|
||||
[pytorch_dist],
|
||||
computation_backends=computation_backend,
|
||||
platform=platform,
|
||||
python_version=python_version,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("channel", CHANNELS)
|
||||
def test_find_links_channel_smoke(channel):
|
||||
assert ltt.find_links(
|
||||
["torch"], computation_backends={cb.CPUBackend()}, channel=channel
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("python_version", PYTHON_VERSIONS)
|
||||
def test_mac_torch_ge_1_0_0(patch_extract_dists, patch_run, python_version):
|
||||
# See https://github.com/pmeier/light-the-torch/issues/34
|
||||
dists = ["torch"]
|
||||
patch_extract_dists(return_value=dists)
|
||||
|
||||
links = ltt.find_links(
|
||||
dists, python_version=python_version, platform="macosx_10_9_x86_64"
|
||||
)
|
||||
version = Version(Wheel(links[0]).version)
|
||||
|
||||
assert version >= Version("1.0.0")
|
||||
@@ -1,17 +0,0 @@
|
||||
import pytest
|
||||
|
||||
import light_the_torch.computation_backend as cb
|
||||
|
||||
|
||||
class GenericComputationBackend(cb.ComputationBackend):
|
||||
@property
|
||||
def local_specifier(self):
|
||||
return "generic"
|
||||
|
||||
def __lt__(self, other):
|
||||
return NotImplemented
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def generic_backend():
|
||||
return GenericComputationBackend()
|
||||
@@ -1,34 +0,0 @@
|
||||
import importlib
|
||||
import pkgutil
|
||||
|
||||
import light_the_torch as package_under_test
|
||||
|
||||
|
||||
def test_importability(subtests):
|
||||
def is_private(name):
|
||||
return name.rsplit(".", 1)[-1].startswith("_")
|
||||
|
||||
def onerror(name):
|
||||
if is_private(name):
|
||||
return
|
||||
|
||||
with subtests.test(name=name):
|
||||
raise
|
||||
|
||||
for finder, name, is_package in pkgutil.walk_packages(
|
||||
path=package_under_test.__path__,
|
||||
prefix=f"{package_under_test.__name__}.",
|
||||
onerror=onerror,
|
||||
):
|
||||
if is_private(name):
|
||||
continue
|
||||
|
||||
if not is_package:
|
||||
try:
|
||||
importlib.import_module(name)
|
||||
except Exception:
|
||||
onerror(name)
|
||||
|
||||
|
||||
def test_version_availability():
|
||||
assert isinstance(package_under_test.__version__, str)
|
||||
60
tox.ini
60
tox.ini
@@ -1,60 +0,0 @@
|
||||
[tox]
|
||||
;See link below for available options
|
||||
;https://tox.readthedocs.io/en/latest/config.html
|
||||
|
||||
isolated_build = True
|
||||
envlist = py{36, 37, 38}
|
||||
skip_missing_interpreters = True
|
||||
|
||||
[testenv]
|
||||
deps =
|
||||
pytest
|
||||
pytest-subtests
|
||||
pytest-mock
|
||||
pytest-cov
|
||||
commands =
|
||||
pytest \
|
||||
-c pytest.ini \
|
||||
--cov=light_the_torch \
|
||||
--cov-report=xml \
|
||||
--cov-config=.coveragerc \
|
||||
{posargs}
|
||||
|
||||
[testenv:format]
|
||||
requires =
|
||||
pre-commit
|
||||
whitelist_externals =
|
||||
pre-commit
|
||||
skip_install = True
|
||||
deps =
|
||||
commands =
|
||||
pre-commit run --all-files
|
||||
|
||||
[testenv:lint]
|
||||
whitelist_externals =
|
||||
pre-commit
|
||||
requires =
|
||||
pre-commit
|
||||
deps =
|
||||
flake8 >= 3.8
|
||||
mypy
|
||||
commands =
|
||||
pre-commit run --all-files
|
||||
flake8 \
|
||||
--config=.flake8
|
||||
mypy \
|
||||
--config-file=mypy.ini
|
||||
|
||||
[testenv:publishable]
|
||||
whitelist_externals =
|
||||
rm
|
||||
skip_install = True
|
||||
deps =
|
||||
check-wheel-contents
|
||||
pep517
|
||||
twine
|
||||
commands =
|
||||
rm -rf build dist light_the_torch.egg-info
|
||||
python -m pep517.build --source --binary .
|
||||
twine check dist/*
|
||||
check-wheel-contents dist
|
||||
Reference in New Issue
Block a user