rewrite project (#60)

* replace tox with doit (#42)

* replace tox with doit

* fix dev env setup

* fix dev env setup

* fix dev env setup

* apply fixes to all workflows

* try fix path

* try pip cache

* fix output location

* fix quotes

* fix quotes

* move path file extraction to script

* setup cache

* remove debug step

* add pre-commit caching

* rewrite ltt to wrap pip directly rather than use its internals

* add publish task and fix workflow

* fix publish

* update pre-commit hooks

* remove mypy

* cleanup

* add setup task

* fix some bugs

* disable output capturing on publish task

* fix publish task

* install wheel package in CI dev setup

* Rewrite smoke and computation backend tests (#44)

* fix publish task

* install wheel package in CI dev setup

* fix computation backend tests

* fix smoke tests

* delete other tests for now

* fix test workflow

* add CLI tests after rewrite (#46)

* rewrite README (#45)

* rewrite README

* update how does it work

* refactor the why section

* fix typo

Co-authored-by: James Butler <jamesobutler@users.noreply.github.com>

* try appdirs to find pip cache path (#48)

* try appdirs to find pip cache path

* print

* fix app_author

* cleanup

* fix

* cleanup

* more cleanup

* refactor dodo (#47)

* refactor dodo

* fix test CI

* address review comments

Co-authored-by: Tony Fast <tony.fast@gmail.com>

* add default tasks (#49)

* fix module entrypoint (#50)

* fix candidate selection (#52)

* remove test task passthrough (#53)

* disable ROCm wheels (#54)

* fix ROCm deselection (#56)

* relax ROCm deselection even further (#57)

* add naive torch install test (#58)

* add naive torch install test

* trigger CI

* add cpuonly check

* don't fail fast

* add pytorch channel to the test matrix

* don't test LTS channel on 3.10

* add check without specifying the channel

* Revert "add check without specifying the channel"

This reverts commit 0842abf50f.

* use extra index rather than find links for link patching (#59)

* update README

* fix Windows install

Co-authored-by: James Butler <jamesobutler@users.noreply.github.com>
Co-authored-by: Tony Fast <tony.fast@gmail.com>
Philip Meier
2022-04-04 20:36:17 +02:00
committed by GitHub
parent 7467b915bb
commit 40471ca049
53 changed files with 1221 additions and 2523 deletions

.flake8

@@ -1,22 +1,19 @@
[flake8]
# See link below for available options
# https://flake8.pycqa.org/en/latest/user/options.html#options-and-their-descriptions
# Move this to pyproject.toml as soon as it is supported.
# See https://gitlab.com/pycqa/flake8/issues/428
# See https://flake8.pycqa.org/en/latest/user/options.html#options-and-their-descriptions
# for available options
exclude =
.git,
.venv,
.eggs,
.mypy_cache,
.pytest_cache,
.tox,
__pycache__,
**/__pycache__,
*.pyc,
ignore = E203, E501, W503
max-line-length = 88
.venv,
build,
# See https://www.flake8rules.com/ for a list of all builtin error codes
ignore =
E203, E501, W503
per-file-ignores =
__init__.py: F401, F403, F405
conftest.py: F401, F403, F405
__init__.py: F401
conftest.py: F401
show_source = True
statistics = True


@@ -0,0 +1,46 @@
name: setup-dev-env
description: "Setup development environment"
inputs:
python-version:
default: "3.7"
runs:
using: composite
steps:
- name: Set up python
uses: actions/setup-python@v2
with:
python-version: ${{ inputs.python-version }}
- name: Get pip cache path
id: get-cache-path
shell: bash
run: |
pip install appdirs
CACHE_PATH=`python -c "import appdirs; print(appdirs.user_cache_dir('pip', appauthor=False))"`
echo "::set-output name=path::$CACHE_PATH"
- name: Restore pip cache
uses: actions/cache@v2
with:
path: ${{ steps.get-cache-path.outputs.path }}
key: ${{ runner.os }}-pip-${{ hashFiles('requirements-dev.txt') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Upgrade or install system packages
shell: bash
run: python -m pip install --upgrade pip wheel
- name: Checkout repository
uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Install development requirements and project
shell: bash
run: |
pip install doit
doit install
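The cache-path step in this action delegates the lookup to `appdirs`. As a rough, stdlib-only sketch of what `appdirs.user_cache_dir("pip", appauthor=False)` resolves to on each OS (an approximation for illustration only; the real library handles more edge cases):

```python
import os
import pathlib
import platform

def pip_cache_dir() -> pathlib.Path:
    # Approximates appdirs.user_cache_dir("pip", appauthor=False):
    # %LOCALAPPDATA%\pip\Cache on Windows, ~/Library/Caches/pip on macOS,
    # $XDG_CACHE_HOME/pip (default ~/.cache/pip) elsewhere.
    home = pathlib.Path.home()
    system = platform.system()
    if system == "Windows":
        base = pathlib.Path(os.environ.get("LOCALAPPDATA", home / "AppData" / "Local"))
        return base / "pip" / "Cache"
    if system == "Darwin":
        return home / "Library" / "Caches" / "pip"
    base = pathlib.Path(os.environ.get("XDG_CACHE_HOME", home / ".cache"))
    return base / "pip"

print(pip_cache_dir())
```

Restoring this directory from the Actions cache keyed on `requirements-dev.txt` is what makes repeated `pip install` runs in CI fast.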

.github/workflows/install.yml

@@ -0,0 +1,55 @@
name: install
on:
push:
branches:
- main
- releases/*
pull_request:
jobs:
torch_cpu:
strategy:
matrix:
os:
- ubuntu-latest
- windows-latest
- macos-latest
python-version:
- "3.7"
- "3.8"
- "3.9"
- "3.10"
pytorch-channel:
- stable
- test
- nightly
- lts
exclude:
- os: macos-latest
pytorch-channel: lts
- python-version: "3.10"
pytorch-channel: lts
fail-fast: false
runs-on: ${{ matrix.os }}
steps:
- name: Checkout repository
uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup development environment
uses: ./.github/actions/setup-dev-env
with:
python-version: ${{ matrix.python-version }}
- name: Install torch
run: ltt install --cpuonly --pytorch-channel=${{ matrix.pytorch-channel }} torch
- name: Check if CPU only
run:
python -c "import sys, torch; sys.exit(hasattr(torch._C,
'_cuda_getDeviceCount'))"


@@ -7,39 +7,27 @@ on:
- releases/*
pull_request:
paths:
- "**.py"
- "pyproject.toml"
- ".pre-commit-config.yaml"
- ".flake8"
- "mypy.ini"
- "tox.ini"
- "requirements-dev.txt"
- ".github/workflows/lint.yml"
jobs:
check:
runs-on: ubuntu-latest
steps:
- name: Set up python
uses: actions/setup-python@v2
with:
python-version: "3.6"
- name: Upgrade pip
run: python -m pip install --upgrade pip
- name: Upgrade or install additional system packages
run: pip install --upgrade setuptools virtualenv wheel
- name: Checkout repository
uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Install dev requirements
run: pip install -r requirements-dev.txt
- name: Setup development environment
uses: ./.github/actions/setup-dev-env
- name: Restore pre-commit cache
uses: actions/cache@v2
with:
path: ~/.cache/pre-commit
key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}
restore-keys: |
pre-commit-
- name: Run lint
run: tox -e lint
run: doit lint


@@ -2,38 +2,27 @@ name: publish
on:
release:
types: [created]
types:
- published
jobs:
pypi:
runs-on: ubuntu-latest
steps:
- name: Set up python
uses: actions/setup-python@v2
with:
python-version: "3.6"
- name: Upgrade pip
run: python -m pip install --upgrade pip
- name: Upgrade or install additional system packages
run: pip install --upgrade setuptools virtualenv wheel
- name: Checkout repository
uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Install pep517 and twine
run: pip install pep517 twine
- name: Setup development environment
uses: ./.github/actions/setup-dev-env
with:
python-version: "3.7"
- name: Build source and binary
run: python -m pep517.build --source --binary .
- name: Upload to PyPI
- name: Publish to PyPI
env:
TWINE_REPOSITORY: pypi
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: twine upload dist/*
run: doit publish


@@ -8,41 +8,31 @@ on:
pull_request:
paths:
- ".github/workflows/publishable.yml"
- ".github/actions/setup-dev-env/**"
- "light_the_torch/**"
- "pyproject.toml"
- "setup.cfg"
- ".gitignore"
- "CONTRIBUTING.rst"
- "dodo.py"
- "LICENSE"
- "MANIFEST.in"
- "pyproject.toml"
- "README.rst"
- "tox.ini"
- "requirements-dev.txt"
- ".github/workflows/publishable.yml"
- "setup.cfg"
jobs:
pypi:
runs-on: ubuntu-latest
steps:
- name: Set up python
uses: actions/setup-python@v2
with:
python-version: "3.6"
- name: Upgrade pip
run: python -m pip install --upgrade pip
- name: Upgrade or install additional system packages
run: pip install --upgrade setuptools virtualenv wheel
- name: Checkout repository
uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Install dev requirements
run: pip install -r requirements-dev.txt
- name: Setup development environment
uses: ./.github/actions/setup-dev-env
- name: Run unit tests
run: tox -e publishable
- name: Check if publishable
run: doit publishable


@@ -8,14 +8,16 @@ on:
pull_request:
paths:
- ".github/workflows/tests.yml"
- ".github/actions/setup-dev-env/**"
- "light_the_torch/**"
- "tests/**"
- "pytest.ini"
- "tox.ini"
- ".coveragerc"
- "codecov.yml"
- "dodo.py"
- "pytest.ini"
- "requirements-dev.txt"
- ".github/workflows/tests.yml"
- "setup.cfg"
schedule:
- cron: "0 4 * * *"
@@ -24,62 +26,37 @@ jobs:
unit:
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python: ['3.6', '3.7', '3.8']
fail-fast: true
os:
- ubuntu-latest
- windows-latest
- macos-latest
python-version:
- "3.7"
- "3.8"
- "3.9"
- "3.10"
runs-on: ${{ matrix.os }}
env:
OS: ${{ matrix.os }}
PYTHON: ${{ matrix.python }}
PYTHON_VERSION: ${{ matrix.python-version }}
steps:
- name: Set up python
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python }}
- name: Upgrade pip
run: python -m pip install --upgrade pip
- name: Upgrade or install additional system packages
run: pip install --upgrade setuptools virtualenv wheel
- name: Checkout repository
uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Install dev requirements
run: pip install -r requirements-dev.txt
- name: Setup development environment
uses: ./.github/actions/setup-dev-env
with:
python-version: ${{ matrix.python-version }}
- name: Run unit tests
run: tox -e py -- --skip-large-download
run: doit test
- name: Upload coverage
uses: codecov/codecov-action@v1.0.7
uses: codecov/codecov-action@v2.1.0
with:
env_vars: OS,PYTHON
cli:
runs-on: ubuntu-latest
steps:
- name: Set up python
uses: actions/setup-python@v2
with:
python-version: "3.6"
- name: Upgrade and install additional system packages
run: pip install --upgrade pip setuptools virtualenv wheel
- name: Checkout repository
uses: actions/checkout@v2
- name: Install project
run: pip install .
- name: Test CLI
run: |
ltt --version
ltt --help
flags: unit
env_vars: OS,PYTHON_VERSION

.gitignore

@@ -1,5 +1,7 @@
light_the_torch/_version.py
.doit.db*
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]


@@ -1,20 +1,26 @@
repos:
- repo: https://github.com/timothycrosley/isort
rev: "4.3.21"
hooks:
- id: isort
args: [--settings-path=pyproject.toml, --filter-files]
additional_dependencies: [toml]
- repo: https://github.com/psf/black
rev: 19.10b0
hooks:
- id: black
args: [--config=pyproject.toml]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.1.0
rev: v4.1.0
hooks:
- id: check-added-large-files
- id: check-docstring-first
- id: check-toml
- id: check-yaml
- id: trailing-whitespace
- id: mixed-line-ending
args:
- --fix=lf
- id: end-of-file-fixer
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v2.5.1
hooks:
- id: prettier
types_or:
- markdown
- toml
- yaml
- repo: https://github.com/omnilib/ufmt
rev: v1.3.2
hooks:
- id: ufmt
additional_dependencies:
- black == 22.1.0
- usort == 1.0.2

.prettierrc.yaml

@@ -0,0 +1,2 @@
proseWrap: always
printWidth: 88


@@ -58,7 +58,7 @@ To run the full lint check locally run
Tests
-----
``light-the-torch`` uses `pytest <https://docs.pytest.org/en/stable/>`_ to run the test
``light-the-torch`` uses `pytest <https://docs.pytest.org/en/stable/>`_ to run the test
suite. You can run it locally with
.. code-block:: sh


@@ -6,8 +6,8 @@ exclude .flake8
exclude .gitignore
exclude .pre-commit-config.yaml
exclude codecov.yml
exclude dodo.py
exclude MANIFEST.in
exclude mypy.ini
exclude pytest.ini
exclude requirements-dev.txt
exclude tox.ini

README.md

@@ -0,0 +1,126 @@
# `light-the-torch`
[![BSD-3-Clause License](https://img.shields.io/github/license/pmeier/light-the-torch)](https://opensource.org/licenses/BSD-3-Clause)
[![Project Status: WIP](https://www.repostatus.org/badges/latest/wip.svg)](https://www.repostatus.org/#wip)
[![Code coverage via codecov.io](https://codecov.io/gh/pmeier/light-the-torch/branch/main/graph/badge.svg)](https://codecov.io/gh/pmeier/light-the-torch)
`light-the-torch` is a small utility that wraps `pip` to ease the installation process
for PyTorch distributions and third-party packages that depend on them. It auto-detects
compatible CUDA versions from the local setup and installs the correct PyTorch binaries
without user interference.
- [Why do I need it?](#why-do-i-need-it)
- [How do I install it?](#how-do-i-install-it)
- [How do I use it?](#how-do-i-use-it)
- [How does it work?](#how-does-it-work)
## Why do I need it?
PyTorch distributions are fully `pip install`'able, but PyPI, the default `pip` search
index, has some limitations:
1. PyPI normally only allows binaries up to a size of
[approximately 60 MB](https://github.com/pypa/packaging-problems/issues/86). One can
[request a file size limit increase](https://pypi.org/help/#file-size-limit) (and the
PyTorch team probably does that for every release), but it is still not enough:
although PyTorch has pre-built binaries for Windows with CUDA, they cannot be
installed through PyPI due to their size.
2. PyTorch uses local version specifiers to indicate for which computation backend the
binary was compiled, for example `torch==1.11.0+cpu`. Unfortunately, local specifiers
are not allowed on PyPI. Thus, only the binaries compiled with one CUDA version are
uploaded without an indication of the CUDA version. If you do not have a CUDA capable
GPU, downloading this is only a waste of bandwidth and disk capacity. If on the other
hand your NVIDIA driver version simply doesn't support the CUDA version the binary
was compiled with, you can't use any of the GPU features.
To overcome this, PyTorch also hosts _all_ binaries
[themselves](https://download.pytorch.org/whl/torch_stable.html). To access them, you
can still use `pip install`, but some
[additional options](https://pytorch.org/get-started/locally/) are needed:
```shell
pip install torch --extra-index-url https://download.pytorch.org/whl/cu113
```
While this is certainly an improvement, it still has a few downsides:
1. You need to know what computation backend, e.g. CUDA 11.3 (`cu113`), is supported on
your local machine. This can be quite challenging for new users and at least tedious
for more experienced ones.
2. Besides the stable binaries, PyTorch also offers nightly, test, and long-term support
(LTS) ones. To install them, you need a different `--extra-index-url` value for each.
For the nightly and test channel you also need to supply the `--pre` option.
3. If you want to install any package hosted on PyPI that depends on PyTorch, you always
have to specify all PyTorch distributions to install as well. Otherwise, the `--extra-index-url`
flag is ignored and the PyTorch distributions hosted on PyPI will be installed.
If any of these points don't sound appealing to you, and you just want to have the same
user experience as `pip install` for PyTorch distributions, `light-the-torch` was made
for you.
## How do I install it?
Installing `light-the-torch` is as easy as
```shell
pip install --pre light-the-torch
```
Since it depends on `pip` and it might be upgraded during installation,
[Windows users](https://pip.pypa.io/en/stable/installation/#upgrading-pip) should
install it with
```shell
py -m pip install --pre light-the-torch
```
## How do I use it?
After `light-the-torch` is installed you can use its CLI interface `ltt` as drop-in
replacement for `pip`:
```shell
ltt install torch
```
In fact, `ltt` is `pip` with a few added options:
- By default, `ltt` uses the local NVIDIA driver version to select the correct binary
for you. You can pass the `--pytorch-computation-backend` option to manually specify
the computation backend you want to use:
```shell
ltt install --pytorch-computation-backend=cu102 torch
```
- By default, `ltt` installs stable PyTorch binaries. To install binaries from nightly,
test, or LTS channels pass the `--pytorch-channel` option:
```shell
ltt install --pytorch-channel=nightly torch
```
If `--pytorch-channel` is not passed, using `pip`'s builtin `--pre` option will
install PyTorch test binaries.
Of course you are not limited to install only PyTorch distributions. Everything shown
above also works if you install packages that depend on PyTorch:
```shell
ltt install --pytorch-computation-backend=cpu --pytorch-channel=nightly pystiche
```
## How does it work?
The authors of `pip` **do not condone** the use of `pip` internals as they might break
without warning. As a result, `pip` has no plugin mechanism to hook into
specific tasks.
`light-the-torch` works by monkey-patching `pip` internals at runtime:
- While searching for a download link for a PyTorch distribution, `light-the-torch`
replaces the default search index with an official PyTorch download link. This is
equivalent to calling `pip install` with the `--extra-index-url` option only for
PyTorch distributions.
- While evaluating possible PyTorch installation candidates, `light-the-torch` culls
binaries incompatible with the hardware.


@@ -1,333 +0,0 @@
light-the-torch
===============
.. start-badges
.. list-table::
:stub-columns: 1
* - package
- |license| |status|
* - code
- |black| |mypy| |lint|
* - tests
- |tests| |coverage|
.. end-badges
``light-the-torch`` offers a small CLI (and
`tox plugin <https://github.com/pmeier/tox-ltt>`_) based on ``pip`` to install PyTorch
distributions from the stable releases. Similar to the platform and Python version, the
computation backend is auto-detected from the available hardware preferring CUDA over
CPU.
Motivation
==========
With each release of a PyTorch distribution (``torch``, ``torchvision``,
``torchaudio``, ``torchtext``) the wheels are published for combinations of different
computation backends (CPU, CUDA), platforms, and Python versions. Unfortunately, a
differentation based on the computation backend is not supported by
`PEP 440 <https://www.python.org/dev/peps/pep-0440/>`_ . As a workaround the
computation backend is added as a local specifier. For example
.. code-block:: sh
torch==1.5.1+cpu
Due to this restriction only the wheels of the latest CUDA release are uploaded to
`PyPI <https://pypi.org/search/?q=torch>`_ and thus easily ``pip install`` able. For
other CUDA versions or the installation without CUDA support, one has to resort to
manual version specification:
.. code-block:: sh
pip install -f https://download.pytorch.org/whl/torch_stable.html torch==1.5.1+cu101
This is especially frustrating if one wants to install packages that depend on one or
several PyTorch distributions: for each package the required PyTorch distributions have
to be manually tracked down, resolved, and installed before the other requirements can
be installed.
``light-the-torch`` was developed to overcome this.
Installation
============
The latest **published** version can be installed with
.. code-block:: sh
pip install light-the-torch
The latest, potentially unstable **development** version can be installed with
.. code-block::
pip install git+https://github.com/pmeier/light-the-torch
Usage
=====
.. note::
The following examples were run on a linux machine with Python 3.6 and CUDA 10.1. The
distributions hosted on PyPI were built with CUDA 10.2.
CLI
---
The CLI of ``light-the-torch`` is invoked with its shorthand ``ltt``
.. code-block:: sh
$ ltt --help
usage: ltt [-h] [-V] {install,extract,find} ...
optional arguments:
-h, --help show this help message and exit
-V, --version show light-the-torch version and path and exit
subcommands:
{install,extract,find}
``ltt install``
^^^^^^^^^^^^^^^
.. code-block:: sh
$ ltt install --help
usage: ltt install [-h] [--force-cpu] [--pytorch-only]
[--install-cmd INSTALL_CMD] [--verbose]
[args [args ...]]
Install PyTorch distributions from the stable releases. The computation
backend is auto-detected from the available hardware preferring CUDA over CPU.
positional arguments:
args arguments of 'pip install'. Optional arguments have to
be seperated by '--'
optional arguments:
-h, --help show this help message and exit
--force-cpu disable computation backend auto-detection and use CPU
instead
--pytorch-only install only PyTorch distributions
--install-cmd INSTALL_CMD
installation command for the PyTorch distributions and
additional packages. Defaults to 'python -m pip
install {packages}'
--verbose print more output to STDOUT. For fine control use -v /
--verbose and -q / --quiet of the 'pip install'
options
``ltt install`` is a drop-in replacement for ``pip install`` without worrying about the
computation backend:
.. code-block:: sh
$ ltt install torch torchvision
[...]
Successfully installed future-0.18.2 numpy-1.19.0 pillow-7.2.0 torch-1.5.1+cu101 torchvision-0.6.1+cu101
[...]
``ltt install`` is also able to handle packages that depend on PyTorch distributions:
.. code-block:: sh
$ ltt install kornia
[...]
Successfully installed future-0.18.2 numpy-1.19.0 torch-1.5.0+cu101
[...]
Successfully installed kornia-0.3.1
``ltt extract``
^^^^^^^^^^^^^^^
.. code-block:: sh
$ ltt extract --help
usage: ltt extract [-h] [--verbose] [args [args ...]]
Extract required PyTorch distributions
positional arguments:
args arguments of 'pip install'. Optional arguments have to be
seperated by '--'
optional arguments:
-h, --help show this help message and exit
--verbose print more output to STDOUT. For fine control use -v / --verbose
and -q / --quiet of the 'pip install' options
``ltt extract`` extracts the required PyTorch distributions out of packages:
.. code-block:: sh
$ ltt extract kornia
torch==1.5.0
.. warning::
Internally, ``light-the-torch`` uses the ``pip`` resolver which, as of now,
unfortunately allows conflicting dependencies:
.. code-block:: sh
$ ltt extract kornia "torch>1.5"
torch>1.5
``ltt find``
^^^^^^^^^^^^
.. code-block:: sh
$ ltt find --help
usage: ltt find [-h] [--computation-backend COMPUTATION_BACKEND]
[--platform PLATFORM] [--python-version PYTHON_VERSION]
[--verbose]
[args [args ...]]
Find wheel links for the required PyTorch distributions
positional arguments:
args arguments of 'pip install'. Optional arguments have to
be seperated by '--'
optional arguments:
-h, --help show this help message and exit
--computation-backend COMPUTATION_BACKEND
Only use wheels compatible with COMPUTATION_BACKEND,
for example 'cu102' or 'cpu'. Defaults to the
computation backend of the running system, preferring
CUDA over CPU.
--platform PLATFORM Only use wheels compatible with <platform>. Defaults
to the platform of the running system.
--python-version PYTHON_VERSION
The Python interpreter version to use for wheel and
"Requires-Python" compatibility checks. Defaults to a
version derived from the running interpreter. The
version can be specified using up to three dot-
separated integers (e.g. "3" for 3.0.0, "3.7" for
3.7.0, or "3.7.3"). A major-minor version can also be
given as a string without dots (e.g. "37" for 3.7.0).
--verbose print more output to STDOUT. For fine control use -v /
--verbose and -q / --quiet of the 'pip install'
options
``ltt find`` finds the links to the wheels of the required PyTorch distributions:
.. code-block:: sh
$ ltt find torchaudio > requirements.txt
$ cat requirements.txt
https://download.pytorch.org/whl/cu101/torch-1.5.1%2Bcu101-cp36-cp36m-linux_x86_64.whl
https://download.pytorch.org/whl/torchaudio-0.5.1-cp36-cp36m-linux_x86_64.whl
The ``--computation-backend``, ``--platform``, and ``python-version`` options can be
used pin wheel properties instead of auto-detecting them:
.. code-block:: sh
$ ltt find \
--computation-backend cu92 \
--platform win_amd64 \
--python-version 3.7 \
torchtext
https://download.pytorch.org/whl/cu92/torch-1.5.1%2Bcu92-cp37-cp37m-win_amd64.whl
https://download.pytorch.org/whl/torchtext-0.6.0-py3-none-any.whl
Python
------
``light-the-torch`` exposes two functions that can be used from Python:
.. code-block:: python
import light_the_torch as ltt
help(ltt.extract_dists)
.. code-block::
Help on function extract_dists in module light_the_torch._pip.extract:
extract_dists(pip_install_args:List[str], verbose:bool=False) -> List[str]
Extract direct or indirect required PyTorch distributions.
Args:
pip_install_args: Arguments passed to ``pip install`` that will be searched for
required PyTorch distributions
verbose: If ``True``, print additional information to STDOUT.
Returns:
Resolved required PyTorch distributions.
.. code-block:: python
import light_the_torch as ltt
help(ltt.find_links)
.. code-block::
Help on function find_links in module light_the_torch._pip.find:
find_links(pip_install_args:List[str], computation_backend:Union[str, light_the_torch.computation_backend.ComputationBackend, NoneType]=None, platform:Union[str, NoneType]=None, python_version:Union[str, NoneType]=None, verbose:bool=False) -> List[str]
Find wheel links for direct or indirect PyTorch distributions with given
properties.
Args:
pip_install_args: Arguments passed to ``pip install`` that will be searched for
required PyTorch distributions
computation_backend: Computation backend, for example ``"cpu"`` or ``"cu102"``.
Defaults to the available hardware of the running system preferring CUDA
over CPU.
platform: Platform, for example ``"linux_x86_64"`` or ``"win_amd64"``. Defaults
to the platform of the running system.
python_version: Python version, for example ``"3"`` or ``"3.7"``. Defaults to
the version of the running interpreter.
verbose: If ``True``, print additional information to STDOUT.
Returns:
Wheel links with given properties for all required PyTorch distributions.
.. note::
Optional arguments for ``pip install`` have to be passed after a ``--`` seperator.
.. |license|
image:: https://img.shields.io/badge/License-BSD%203--Clause-blue.svg
:target: https://opensource.org/licenses/BSD-3-Clause
:alt: License
.. |status|
image:: https://www.repostatus.org/badges/latest/wip.svg
:alt: Project Status: WIP
:target: https://www.repostatus.org/#wip
.. |black|
image:: https://img.shields.io/badge/code%20style-black-000000.svg
:target: https://github.com/psf/black
:alt: black
.. |mypy|
image:: http://www.mypy-lang.org/static/mypy_badge.svg
:target: http://mypy-lang.org/
:alt: mypy
.. |lint|
image:: https://github.com/pmeier/light-the-torch/workflows/lint/badge.svg
:target: https://github.com/pmeier/light-the-torch/actions?query=workflow%3Alint+branch%3Amain
:alt: Lint status via GitHub Actions
.. |tests|
image:: https://github.com/pmeier/light-the-torch/workflows/tests/badge.svg
:target: https://github.com/pmeier/light-the-torch/actions?query=workflow%3Atests+branch%3Amain
:alt: Test status via GitHub Actions
.. |coverage|
image:: https://codecov.io/gh/pmeier/light-the-torch/branch/main/graph/badge.svg
:target: https://codecov.io/gh/pmeier/light-the-torch
:alt: Test coverage via codecov.io

dodo.py

@@ -0,0 +1,144 @@
import os
import pathlib
import shlex
from doit.action import CmdAction
HERE = pathlib.Path(__file__).parent
PACKAGE_NAME = "light_the_torch"
CI = os.environ.get("CI") == "1"
DOIT_CONFIG = dict(
verbosity=2,
backend="json",
default_tasks=[
"lint",
"test",
"publishable",
],
)
def do(*cmd):
if len(cmd) == 1:
cmd = cmd[0]
if isinstance(cmd, str):
cmd = shlex.split(cmd)
return CmdAction(cmd, shell=False, cwd=HERE)
def _install_dev_requirements(pip="python -m pip"):
return f"{pip} install -r requirements-dev.txt"
def _install_project(pip="python -m pip"):
return f"{pip} install -e ."
def task_install():
"""Installs all development requirements and light-the-torch in development mode"""
yield dict(
name="dev",
file_dep=[HERE / "requirements-dev.txt"],
actions=[do(_install_dev_requirements())],
)
yield dict(
name="project",
actions=[
do(_install_project()),
],
)
def task_setup():
"""Sets up a development environment for light-the-torch"""
dev_env = HERE / ".venv"
pip = dev_env / "bin" / "pip"
return dict(
actions=[
do(f"virtualenv {dev_env} --prompt='(light-the-torch-dev) '"),
do(_install_dev_requirements(pip)),
do(_install_project(pip)),
lambda: print(
f"run `source {dev_env / 'bin' / 'activate'}` to activate the virtual environment"
),
],
clean=[do(f"rm -rf {dev_env}")],
uptodate=[lambda: dev_env.exists()],
)
def task_format():
"""Auto-formats all project files"""
return dict(
actions=[
do("pre-commit run --all-files"),
]
)
def task_lint():
"""Lints all project files"""
return dict(
actions=[
do("flake8 --config=.flake8"),
]
)
def task_test():
"""Runs the test suite"""
return dict(
actions=[do(f"pytest -c pytest.ini --cov-report={'xml' if CI else 'term'}")],
)
def task_build():
"""Builds the source distribution and wheel of light-the-torch"""
return dict(
actions=[
do("python -m build ."),
],
clean=[
do(f"rm -rf build dist {PACKAGE_NAME}.egg-info"),
],
)
def task_publishable():
"""Checks if metadata is correct"""
yield dict(
name="twine",
actions=[
# We need the lambda here to lazily glob the files in dist/*, since they
# are only created by the build task rather than when this task is
# created.
do(lambda: ["twine", "check", *list((HERE / "dist").glob("*"))]),
],
task_dep=["build"],
)
yield dict(
name="check-wheel-contents",
actions=[
do("check-wheel-contents dist"),
],
task_dep=["build"],
)
def task_publish():
"""Publishes light-the-torch to PyPI"""
return dict(
# We need the lambda here to lazily glob the files in dist/*, since they are
# only created by the build task rather than when this task is created.
actions=[
do(lambda: ["twine", "upload", *list((HERE / "dist").glob("*"))]),
],
task_dep=[
"lint",
"test",
"publishable",
],
)
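The `do()` helper above normalizes commands so that `CmdAction` can run them with `shell=False`: a single string is split with `shlex`, while pre-split sequences and callables pass through untouched. A self-contained sketch of just that normalization:

```python
import shlex

def normalize_cmd(*cmd):
    # Mirrors the branching in dodo.py's do() helper: unwrap a single
    # argument, then shlex-split it if it is a string so no shell is needed.
    if len(cmd) == 1:
        cmd = cmd[0]
    if isinstance(cmd, str):
        cmd = shlex.split(cmd)
    return cmd

print(normalize_cmd("pytest -c pytest.ini"))
# → ['pytest', '-c', 'pytest.ini']
print(normalize_cmd("virtualenv .venv --prompt='(ltt-dev) '"))
# quoting survives the split, which a naive str.split would break
```

Running without a shell sidesteps the quoting bugs the commit history above wrestled with ("fix quotes") and keeps the tasks portable across Windows and POSIX runners.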


@@ -1,3 +1,4 @@
from ._version import version as __version__ # type: ignore[import]
from ._pip import *
try:
from ._version import version as __version__
except ImportError:
__version__ = "UNKNOWN"


@@ -1,4 +1,4 @@
from .cli import main
from ._cli import main
if __name__ == "__main__":
main()


@@ -2,22 +2,10 @@ import platform
import re
import subprocess
from abc import ABC, abstractmethod
from typing import Any, Optional, Set
from typing import Any, List, Optional, Set
from pip._vendor.packaging.version import InvalidVersion, Version
__all__ = [
"ComputationBackend",
"CPUBackend",
"CUDABackend",
"detect_compatible_computation_backends",
]
class ParseError(ValueError):
def __init__(self, string: str) -> None:
super().__init__(f"Unable to parse {string} into a computation backend")
class ComputationBackend(ABC):
@property
@@ -45,8 +33,8 @@ class ComputationBackend(ABC):
@classmethod
def from_str(cls, string: str) -> "ComputationBackend":
parse_error = ParseError(string)
string = string.lower()
parse_error = ValueError(f"Unable to parse {string} into a computation backend")
string = string.strip().lower()
if string == "cpu":
return CPUBackend()
elif string.startswith("cu"):
@@ -97,21 +85,30 @@ class CUDABackend(ComputationBackend):
def _detect_nvidia_driver_version() -> Optional[Version]:
cmd = "nvidia-smi --query-gpu=driver_version --format=csv"
try:
output = (
subprocess.check_output(cmd, shell=True, stderr=subprocess.DEVNULL)
.decode("utf-8")
.strip()
result = subprocess.run(
[
"nvidia-smi",
"--query-gpu=driver_version",
"--format=csv",
],
check=True,
capture_output=True,
text=True,
)
return Version(output.splitlines()[-1])
except (subprocess.CalledProcessError, InvalidVersion):
return Version(result.stdout.splitlines()[-1])
except (FileNotFoundError, subprocess.CalledProcessError, InvalidVersion):
return None
# Table 3 from https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
_MINIMUM_DRIVER_VERSIONS = {
"Linux": {
Version("11.6"): Version("510.39.01"),
Version("11.5"): Version("495.29.05"),
Version("11.4"): Version("470.82.01"),
Version("11.3"): Version("465.19.01"),
Version("11.2"): Version("460.32.03"),
Version("11.1"): Version("455.32"),
Version("11.0"): Version("450.51.06"),
Version("10.2"): Version("440.33"),
@@ -121,9 +118,13 @@ _MINIMUM_DRIVER_VERSIONS = {
Version("9.1"): Version("390.46"),
Version("9.0"): Version("384.81"),
Version("8.0"): Version("375.26"),
Version("7.5"): Version("352.31"),
},
"Windows": {
Version("11.6"): Version("511.23"),
Version("11.5"): Version("496.13"),
Version("11.4"): Version("472.50"),
Version("11.3"): Version("465.89"),
Version("11.2"): Version("461.33"),
Version("11.1"): Version("456.81"),
Version("11.0"): Version("451.82"),
Version("10.2"): Version("441.22"),
@@ -133,26 +134,25 @@ _MINIMUM_DRIVER_VERSIONS = {
Version("9.1"): Version("391.29"),
Version("9.0"): Version("385.54"),
Version("8.0"): Version("376.51"),
Version("7.5"): Version("353.66"),
},
}
def _detect_compatible_cuda_backends() -> Set[CUDABackend]:
def _detect_compatible_cuda_backends() -> List[CUDABackend]:
driver_version = _detect_nvidia_driver_version()
if not driver_version:
return set()
return []
minimum_driver_versions = _MINIMUM_DRIVER_VERSIONS.get(platform.system())
if not minimum_driver_versions:
return set()
return []
return {
return [
CUDABackend(cuda_version.major, cuda_version.minor)
for cuda_version, minimum_driver_version in minimum_driver_versions.items()
if driver_version >= minimum_driver_version
}
]
def detect_compatible_computation_backends() -> Set[ComputationBackend]:
return {CPUBackend(), *_detect_compatible_cuda_backends()}
return {*_detect_compatible_cuda_backends(), CPUBackend()}
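The table lookup above reduces to a simple comparison: every CUDA release whose minimum driver version is at or below the detected driver is compatible. A sketch with a truncated, illustrative excerpt of the Linux column, using plain tuples instead of `packaging.version.Version`:

```python
# Illustrative excerpt of the minimum-driver table above (Linux column);
# versions are modeled as tuples to keep the sketch dependency-free.
MINIMUM_DRIVER_VERSIONS = {
    (11, 6): (510, 39, 1),
    (11, 3): (465, 19, 1),
    (10, 2): (440, 33),
}

def compatible_cuda_versions(driver: tuple) -> list:
    # Tuples compare element-wise, matching the Version comparison above.
    return sorted(
        cuda
        for cuda, minimum in MINIMUM_DRIVER_VERSIONS.items()
        if driver >= minimum
    )
```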

light_the_torch/_cli.py Normal file

@@ -0,0 +1,5 @@
from pip._internal.cli.main import main as pip_main
from ._patch import patch
main = patch(pip_main)

light_the_torch/_patch.py Normal file

@@ -0,0 +1,294 @@
import contextlib
import dataclasses
import enum
import functools
import optparse
import re
import sys
import unittest.mock
from typing import List, Set
from unittest import mock
import pip._internal.cli.cmdoptions
import pip._internal.index.collector
import pip._internal.index.package_finder
from pip._internal.index.package_finder import CandidateEvaluator
from pip._internal.models.candidate import InstallationCandidate
from pip._internal.models.search_scope import SearchScope
import light_the_torch as ltt
from . import _cb as cb
from ._utils import apply_fn_patch
class Channel(enum.Enum):
STABLE = enum.auto()
TEST = enum.auto()
NIGHTLY = enum.auto()
LTS = enum.auto()
@classmethod
def from_str(cls, string):
return cls[string.upper()]
PYTORCH_DISTRIBUTIONS = ("torch", "torchvision", "torchaudio", "torchtext")
def patch(pip_main):
@functools.wraps(pip_main)
def wrapper(argv=None):
if argv is None:
argv = sys.argv[1:]
with apply_patches(argv):
return pip_main(argv)
return wrapper
# adapted from https://stackoverflow.com/a/9307174
class PassThroughOptionParser(optparse.OptionParser):
def __init__(self):
super().__init__(add_help_option=False)
def _process_args(self, largs, rargs, values):
while rargs:
try:
super()._process_args(largs, rargs, values)
except (optparse.BadOptionError, optparse.AmbiguousOptionError) as error:
largs.append(error.opt_str)
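Since ltt forwards everything it does not recognize to pip, the parser must consume only its own options and keep the rest intact. A self-contained demo of the pass-through pattern above:

```python
import optparse

# Self-contained demo of the pass-through pattern above: options the parser
# does not know are collected as leftover arguments instead of raising.
class PassThroughParser(optparse.OptionParser):
    def __init__(self):
        super().__init__(add_help_option=False)

    def _process_args(self, largs, rargs, values):
        while rargs:
            try:
                super()._process_args(largs, rargs, values)
            except (optparse.BadOptionError, optparse.AmbiguousOptionError) as error:
                # keep the unknown option so it can be forwarded to pip
                largs.append(error.opt_str)

parser = PassThroughParser()
parser.add_option("--pytorch-channel")
opts, rest = parser.parse_args(["--pytorch-channel", "nightly", "--upgrade", "torch"])
# opts.pytorch_channel == "nightly"; rest == ["--upgrade", "torch"]
```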
@dataclasses.dataclass
class LttOptions:
computation_backends: Set[cb.ComputationBackend] = dataclasses.field(
default_factory=lambda: {cb.CPUBackend()}
)
channel: Channel = Channel.STABLE
@staticmethod
def computation_backend_parser_options():
return [
optparse.Option(
"--pytorch-computation-backend",
help=(
"Computation backend for compiled PyTorch distributions, "
"e.g. 'cu102', 'cu115', or 'cpu'. "
"Multiple backends can be passed as a comma-separated list. "
"If not specified, the computation backend is detected from the "
"available hardware, preferring CUDA over CPU."
),
),
optparse.Option(
"--cpuonly",
action="store_true",
help=(
"Shortcut for '--pytorch-computation-backend=cpu'. "
"If '--computation-backend' is used simultaneously, "
"it takes precedence over '--cpuonly'."
),
),
]
@staticmethod
def channel_parser_option() -> optparse.Option:
return optparse.Option(
"--pytorch-channel",
help=(
"Channel of the PyTorch wheels. "
"Can be one of 'stable' (default), 'test', 'nightly', or 'lts'."
),
)
@staticmethod
def _parse(argv):
parser = PassThroughOptionParser()
for option in LttOptions.computation_backend_parser_options():
parser.add_option(option)
parser.add_option(LttOptions.channel_parser_option())
parser.add_option("--pre", dest="pre", action="store_true")
opts, _ = parser.parse_args(argv)
return opts
@classmethod
def from_pip_argv(cls, argv: List[str]):
if not argv or argv[0] != "install":
return cls()
opts = cls._parse(argv)
if opts.pytorch_computation_backend is not None:
cbs = {
cb.ComputationBackend.from_str(string)
for string in opts.pytorch_computation_backend.split(",")
}
elif opts.cpuonly:
cbs = {cb.CPUBackend()}
else:
cbs = cb.detect_compatible_computation_backends()
if opts.pytorch_channel is not None:
channel = Channel.from_str(opts.pytorch_channel)
elif opts.pre:
channel = Channel.TEST
else:
channel = Channel.STABLE
return cls(cbs, channel)
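The precedence implemented above: an explicit `--pytorch-computation-backend` wins over `--cpuonly`, which wins over auto-detection, and `--pytorch-channel` wins over `--pre`. A hypothetical, pip-free distillation of those rules (backends as plain strings instead of `ComputationBackend` instances):

```python
# Hypothetical distillation of the precedence rules in from_pip_argv above;
# backends are plain strings here instead of ComputationBackend instances.
def resolve_options(backend_arg, cpuonly, pre, channel_arg, detected=("cpu",)):
    if backend_arg is not None:
        backends = set(backend_arg.split(","))  # explicit option wins
    elif cpuonly:
        backends = {"cpu"}                      # --cpuonly shortcut
    else:
        backends = set(detected)                # fall back to auto-detection
    channel = channel_arg if channel_arg is not None else ("test" if pre else "stable")
    return backends, channel
```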
@contextlib.contextmanager
def apply_patches(argv):
options = LttOptions.from_pip_argv(argv)
patches = [
patch_cli_version(),
patch_cli_options(),
patch_link_collection(options.computation_backends, options.channel),
patch_candidate_selection(options.computation_backends),
]
with contextlib.ExitStack() as stack:
for patch in patches:
stack.enter_context(patch)
yield stack
@contextlib.contextmanager
def patch_cli_version():
with apply_fn_patch(
"pip",
"_internal",
"cli",
"main_parser",
"get_pip_version",
postprocessing=lambda input, output: f"ltt {ltt.__version__} from {ltt.__path__[0]}\n{output}",
):
yield
@contextlib.contextmanager
def patch_cli_options():
def postprocessing(input, output):
for option in LttOptions.computation_backend_parser_options():
input.cmd_opts.add_option(option)
index_group = pip._internal.cli.cmdoptions.index_group
with apply_fn_patch(
"pip",
"_internal",
"cli",
"cmdoptions",
"add_target_python_options",
postprocessing=postprocessing,
):
with unittest.mock.patch.dict(index_group):
options = index_group["options"].copy()
options.append(LttOptions.channel_parser_option)
index_group["options"] = options
yield
@contextlib.contextmanager
def patch_link_collection(computation_backends, channel):
if channel != Channel.LTS:
find_links = []
# TODO: this template is not valid for all backends
channel_path = f"{channel.name.lower()}/" if channel != Channel.STABLE else ""
index_urls = [
f"https://download.pytorch.org/whl/{channel_path}{backend}"
for backend in sorted(computation_backends)
]
else:
# TODO: expand this when there are more LTS versions
# TODO: switch this to index_urls when
# https://github.com/pytorch/pytorch/pull/74753 is resolved
find_links = ["https://download.pytorch.org/whl/lts/1.8/torch_lts.html"]
index_urls = []
search_scope = SearchScope.create(find_links=find_links, index_urls=index_urls)
@contextlib.contextmanager
def context(input):
if input.project_name not in PYTORCH_DISTRIBUTIONS:
yield
return
with mock.patch.object(input.self, "search_scope", search_scope):
yield
with apply_fn_patch(
"pip",
"_internal",
"index",
"collector",
"LinkCollector",
"collect_sources",
context=context,
):
yield
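For non-LTS channels the patched search scope is a list of per-backend index URLs under download.pytorch.org. A sketch of that URL construction (`pytorch_index_urls` is a hypothetical name; as the TODO above notes, the template is not valid for every backend):

```python
# Sketch of the per-backend index URLs built above; as noted in the TODO,
# the template is not guaranteed to hold for every backend.
def pytorch_index_urls(channel: str, backends) -> list:
    # the stable channel lives at the root of the wheel index
    channel_path = "" if channel == "stable" else f"{channel}/"
    return [
        f"https://download.pytorch.org/whl/{channel_path}{backend}"
        for backend in sorted(backends)
    ]
```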
@contextlib.contextmanager
def patch_candidate_selection(computation_backends):
allowed_locals = {None, *computation_backends}
computation_backend_pattern = re.compile(
r"^/whl/(?P<computation_backend>(cpu|cu\d+))/"
)
def preprocessing(input):
input.candidates = [
candidate
for candidate in input.candidates
if candidate.name not in PYTORCH_DISTRIBUTIONS
or candidate.version.local is None
or "rocm" not in candidate.version.local
]
return input
def postprocessing(
input, output: List[InstallationCandidate]
) -> List[InstallationCandidate]:
return [
candidate
for candidate in output
if candidate.name not in PYTORCH_DISTRIBUTIONS
or candidate.version.local in allowed_locals
]
vanilla_sort_key = CandidateEvaluator._sort_key
def sort_key(candidate_evaluator, candidate):
if candidate.name not in PYTORCH_DISTRIBUTIONS:
return vanilla_sort_key(candidate_evaluator, candidate)
if candidate.version.local is not None:
computation_backend_str = candidate.version.local.replace("any", "cpu")
else:
match = computation_backend_pattern.match(candidate.link.path)
computation_backend_str = match["computation_backend"] if match else "cpu"
return (
cb.ComputationBackend.from_str(computation_backend_str),
candidate.version,
)
with apply_fn_patch(
"pip",
"_internal",
"index",
"package_finder",
"CandidateEvaluator",
"get_applicable_candidates",
preprocessing=preprocessing,
postprocessing=postprocessing,
):
with unittest.mock.patch.object(CandidateEvaluator, "_sort_key", new=sort_key):
yield
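The sort key needs a computation backend for every candidate, including wheels whose version carries no local specifier; those fall back to the `/whl/<backend>/` URL path, and finally to plain CPU. The extraction step in isolation:

```python
import re
from typing import Optional

# The fallback chain used by sort_key above: prefer the local version
# specifier ('any' counts as CPU), then the '/whl/<backend>/' URL path,
# and finally assume a plain CPU wheel.
COMPUTATION_BACKEND_PATTERN = re.compile(r"^/whl/(?P<computation_backend>(cpu|cu\d+))/")

def extract_computation_backend(local: Optional[str], link_path: str) -> str:
    if local is not None:
        return local.replace("any", "cpu")
    match = COMPUTATION_BACKEND_PATTERN.match(link_path)
    return match["computation_backend"] if match else "cpu"
```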


@@ -1,2 +0,0 @@
from .extract import *
from .find import *


@@ -1,103 +0,0 @@
import optparse
from typing import Any, Iterable, Type, TypeVar, cast
from pip._internal.commands.install import InstallCommand
from pip._internal.resolution.base import BaseResolver
from pip._internal.resolution.legacy.resolver import Resolver
from pip._internal.utils.logging import setup_logging
from pip._internal.utils.temp_dir import global_tempdir_manager, tempdir_registry
__all__ = [
"InternalLTTError",
"PatchedInstallCommand",
"make_pip_install_parser",
"run",
"new_from_similar",
"PatchedResolverBase",
]
class InternalLTTError(RuntimeError):
def __init__(self) -> None:
msg = (
"Unexpected internal ltt error. If you ever encounter this "
"message during normal operation, please submit a bug report at "
"https://github.com/pmeier/light-the-torch/issues"
)
super().__init__(msg)
class PatchedInstallCommand(InstallCommand):
def __init__(
self, name: str = "name", summary: str = "summary", **kwargs: Any
) -> None:
super().__init__(name, summary, **kwargs)
def make_pip_install_parser() -> optparse.OptionParser:
return cast(optparse.OptionParser, PatchedInstallCommand().parser)
def get_verbosity(options: optparse.Values, verbose: bool) -> int:
if not verbose:
return -1
return cast(int, options.verbose) - cast(int, options.quiet)
def run(
cmd: InstallCommand, args: Iterable[str], options: optparse.Values, verbose: bool
) -> int:
with cmd.main_context():
cmd.tempdir_registry = cmd.enter_context(tempdir_registry())
cmd.enter_context(global_tempdir_manager())
setup_logging(
verbosity=get_verbosity(options, verbose),
no_color=options.no_color,
user_log_file=options.log,
)
return cast(int, cmd.run(options, list(args)))
def get_public_or_private_attr(obj: Any, name: str) -> Any:
try:
return getattr(obj, name)
except AttributeError:
try:
return getattr(obj, f"_{name}")
except AttributeError:
msg = f"'{type(obj)}' has no attribute '{name}' or '_{name}'"
raise AttributeError(msg)
T = TypeVar("T")
def new_from_similar(cls: Type[T], obj: Any, names: Iterable[str], **kwargs: Any) -> T:
attrs = {name: get_public_or_private_attr(obj, name) for name in names}
attrs.update(kwargs)
return cls(**attrs) # type: ignore[call-arg]
class PatchedResolverBase(Resolver):
@classmethod
def from_resolver(cls, resolver: BaseResolver) -> "PatchedResolverBase":
return new_from_similar(
cls,
resolver,
(
"preparer",
"finder",
"wheel_cache",
"upgrade_strategy",
"force_reinstall",
"ignore_dependencies",
"ignore_installed",
"ignore_requires_python",
"use_user_site",
"make_install_req",
"py_version_info",
),
)


@@ -1,105 +0,0 @@
import contextlib
import re
from typing import Any, List, NoReturn, cast
from pip._internal.req.req_install import InstallRequirement
from pip._internal.req.req_set import RequirementSet
from ..compatibility import find_compatible_torch_version
from .common import InternalLTTError, PatchedInstallCommand, PatchedResolverBase, run
__all__ = ["extract_dists"]
def extract_dists(pip_install_args: List[str], verbose: bool = False) -> List[str]:
"""Extract direct or indirect required PyTorch distributions.
Args:
pip_install_args: Arguments passed to ``pip install`` that will be searched for
required PyTorch distributions
verbose: If ``True``, print additional information to STDOUT.
Returns:
Resolved required PyTorch distributions.
"""
cmd = StopAfterPytorchDistsFoundInstallCommand()
options, args = cmd.parser.parse_args(pip_install_args)
try:
run(cmd, args, options, verbose)
except PytorchDistsFound as resolution:
return resolution.dists
else:
raise InternalLTTError
class PytorchDistsFound(RuntimeError):
def __init__(self, dists: List[str]) -> None:
self.dists = dists
class StopAfterPytorchDistsFoundResolver(PatchedResolverBase):
PYTORCH_CORE = "torch"
PYTORCH_SUBS = ("vision", "text", "audio")
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._pytorch_dists = (
self.PYTORCH_CORE,
*[f"{self.PYTORCH_CORE}{sub}" for sub in self.PYTORCH_SUBS],
)
self._pytorch_core_pattern = re.compile(
f"^{self.PYTORCH_CORE}(?!({'|'.join(self.PYTORCH_SUBS)}))"
)
self._required_pytorch_dists: List[str] = []
def _resolve_one(
self, requirement_set: RequirementSet, req_to_install: InstallRequirement
) -> List[InstallRequirement]:
if req_to_install.name not in self._pytorch_dists:
return cast(
List[InstallRequirement],
super()._resolve_one(requirement_set, req_to_install),
)
self._required_pytorch_dists.append(str(req_to_install.req))
return []
def resolve(
self, root_reqs: List[InstallRequirement], check_supported_wheels: bool
) -> NoReturn:
super().resolve(root_reqs, check_supported_wheels)
raise PytorchDistsFound(self.required_pytorch_dists)
@property
def required_pytorch_dists(self) -> List[str]:
dists = self._required_pytorch_dists
if not dists:
return []
# If the distribution was found in an extra requirement, pip appends this as
# additional information. We remove that here.
dists = [dist.split(";")[0] for dist in dists]
if not any(self._pytorch_core_pattern.match(dist) for dist in dists):
torch = self.PYTORCH_CORE
with contextlib.suppress(RuntimeError):
torch_versions = {
find_compatible_torch_version(*dist.split("=="))
for dist in dists
if "==" in dist
}
if len(torch_versions) == 1:
torch = f"{torch}=={torch_versions.pop()}"
dists.insert(0, torch)
return dists
class StopAfterPytorchDistsFoundInstallCommand(PatchedInstallCommand):
def make_resolver(
self, *args: Any, **kwargs: Any
) -> StopAfterPytorchDistsFoundResolver:
resolver = super().make_resolver(*args, **kwargs)
return StopAfterPytorchDistsFoundResolver.from_resolver(resolver)
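Both patched commands share one control-flow trick: resolution runs normally until the information of interest is collected, then a dedicated exception carries it out of pip's deeply nested call stack, which offers no ordinary return path. The bare pattern, with the resolver reduced to a stub:

```python
# Bare sketch of the abort-with-result pattern used above: a dedicated
# exception type smuggles the collected result out of a deep call stack.
class PytorchDistsFound(RuntimeError):
    def __init__(self, dists):
        self.dists = dists

def resolve(requirements):
    # stand-in for pip's resolver: collect the PyTorch requirements, then abort
    dists = [req for req in requirements if req.startswith("torch")]
    raise PytorchDistsFound(dists)

def extract_dists(requirements):
    try:
        resolve(requirements)
    except PytorchDistsFound as resolution:
        return resolution.dists
    raise RuntimeError("resolver returned without raising")  # internal error
```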


@@ -1,400 +0,0 @@
import re
from typing import (
Any,
Collection,
Iterable,
List,
NoReturn,
Optional,
Set,
Text,
Tuple,
Union,
cast,
)
from pip._internal.index.collector import LinkCollector
from pip._internal.index.package_finder import (
CandidateEvaluator,
CandidatePreferences,
LinkEvaluator,
PackageFinder,
)
from pip._internal.models.candidate import InstallationCandidate
from pip._internal.models.link import Link
from pip._internal.models.search_scope import SearchScope
from pip._internal.req.req_install import InstallRequirement
from pip._internal.req.req_set import RequirementSet
from pip._vendor.packaging.version import Version
import light_the_torch.computation_backend as cb
from .common import (
InternalLTTError,
PatchedInstallCommand,
PatchedResolverBase,
new_from_similar,
run,
)
from .extract import extract_dists
__all__ = ["find_links"]
def find_links(
pip_install_args: List[str],
computation_backends: Optional[
Union[cb.ComputationBackend, Collection[cb.ComputationBackend]]
] = None,
channel: str = "stable",
platform: Optional[str] = None,
python_version: Optional[str] = None,
verbose: bool = False,
) -> List[str]:
"""Find wheel links for direct or indirect PyTorch distributions with given
properties.
Args:
pip_install_args: Arguments passed to ``pip install`` that will be searched for
required PyTorch distributions
computation_backends: Collection of supported computation backends, for example
``"cpu"`` or ``"cu102"``. Defaults to the available hardware of the running
system.
channel: Channel of the PyTorch wheels. Can be one of ``"stable"`` (default),
``"lts"``, ``"test"``, and ``"nightly"``.
platform: Platform, for example ``"linux_x86_64"`` or ``"win_amd64"``. Defaults
to the platform of the running system.
python_version: Python version, for example ``"3"`` or ``"3.7"``. Defaults to
the version of the running interpreter.
verbose: If ``True``, print additional information to STDOUT.
Returns:
Wheel links with given properties for all required PyTorch distributions.
"""
if computation_backends is None:
computation_backends = cb.detect_compatible_computation_backends()
elif isinstance(computation_backends, cb.ComputationBackend):
computation_backends = {computation_backends}
else:
computation_backends = set(computation_backends)
if channel not in ("stable", "lts", "test", "nightly"):
raise ValueError(
f"channel can be one of 'stable', 'lts', 'test', or 'nightly', "
f"but got {channel} instead."
)
dists = extract_dists(pip_install_args)
cmd = StopAfterPytorchLinksFoundCommand(
computation_backends=computation_backends, channel=channel
)
pip_install_args = adjust_pip_install_args(dists, platform, python_version)
options, args = cmd.parser.parse_args(pip_install_args)
try:
run(cmd, args, options, verbose)
except PytorchLinksFound as resolution:
return resolution.links
else:
raise InternalLTTError
def adjust_pip_install_args(
pip_install_args: List[str], platform: Optional[str], python_version: Optional[str]
) -> List[str]:
if platform is None and python_version is None:
return pip_install_args
if platform is not None:
pip_install_args = maybe_add_option(
pip_install_args, "--platform", value=platform
)
if python_version is not None:
pip_install_args = maybe_add_option(
pip_install_args, "--python-version", value=python_version
)
return maybe_set_required_options(pip_install_args)
def maybe_add_option(
args: List[str],
option: str,
value: Optional[str] = None,
aliases: Iterable[str] = (),
) -> List[str]:
if any(arg in args for arg in (option, *aliases)):
return args
additional_args = [option]
if value is not None:
additional_args.append(value)
return additional_args + args
def maybe_set_required_options(pip_install_args: List[str]) -> List[str]:
pip_install_args = maybe_add_option(
pip_install_args, "-t", value=".", aliases=("--target",)
)
pip_install_args = maybe_add_option(
pip_install_args, "--only-binary", value=":all:"
)
return pip_install_args
class PytorchLinksFound(RuntimeError):
def __init__(self, links: List[str]) -> None:
self.links = links
class PytorchLinkEvaluator(LinkEvaluator):
HAS_LOCAL_PATTERN = re.compile(r"[+](cpu|cu\d+)$")
EXTRACT_LOCAL_PATTERN = re.compile(r"^/whl/(?P<local_specifier>(cpu|cu\d+))")
@classmethod
def from_link_evaluator(
cls, link_evaluator: LinkEvaluator
) -> "PytorchLinkEvaluator":
return new_from_similar(
cls,
link_evaluator,
(
"project_name",
"canonical_name",
"formats",
"target_python",
"allow_yanked",
"ignore_requires_python",
),
)
def evaluate_link(self, link: Link) -> Tuple[bool, Optional[Text]]:
output = cast(Tuple[bool, Optional[Text]], super().evaluate_link(link))
is_candidate, result = output
if not is_candidate:
return output
result = cast(Text, result)
has_local = self.HAS_LOCAL_PATTERN.search(result) is not None
if has_local:
return output
return True, f"{result}+{self.extract_computation_backend_from_link(link)}"
def extract_computation_backend_from_link(self, link: Link) -> Optional[str]:
match = self.EXTRACT_LOCAL_PATTERN.match(link.path)
if match is None:
return "any"
return match.group("local_specifier")
class PytorchCandidatePreferences(CandidatePreferences):
def __init__(
self,
*args: Any,
computation_backends: Set[cb.ComputationBackend],
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.computation_backends = computation_backends
@classmethod
def from_candidate_preferences(
cls,
candidate_preferences: CandidatePreferences,
computation_backends: Set[cb.ComputationBackend],
) -> "PytorchCandidatePreferences":
return new_from_similar(
cls,
candidate_preferences,
("prefer_binary", "allow_all_prereleases",),
computation_backends=computation_backends,
)
class PytorchCandidateEvaluator(CandidateEvaluator):
_MACOS_PLATFORM_PATTERN = re.compile(r"macosx_\d+_\d+_x86_64")
def __init__(
self,
*args: Any,
computation_backends: Set[cb.ComputationBackend],
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.computation_backends = computation_backends
@classmethod
def from_candidate_evaluator(
cls,
candidate_evaluator: CandidateEvaluator,
computation_backends: Set[cb.ComputationBackend],
) -> "PytorchCandidateEvaluator":
return new_from_similar(
cls,
candidate_evaluator,
(
"project_name",
"supported_tags",
"specifier",
"prefer_binary",
"allow_all_prereleases",
"hashes",
),
computation_backends=computation_backends,
)
def _sort_key(
self, candidate: InstallationCandidate
) -> Tuple[cb.ComputationBackend, Version]:
return (
cb.ComputationBackend.from_str(
candidate.version.local.replace("any", "cpu")
),
candidate.version,
)
def get_applicable_candidates(
self, candidates: List[InstallationCandidate]
) -> List[InstallationCandidate]:
return [
candidate
for candidate in super().get_applicable_candidates(candidates)
if candidate.version.local in self.computation_backends
or candidate.version.local == "any"
]
class PytorchLinkCollector(LinkCollector):
def __init__(
self,
*args: Any,
computation_backends: Set[cb.ComputationBackend],
channel: str = "stable",
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
if channel == "stable":
urls = ["https://download.pytorch.org/whl/torch_stable.html"]
elif channel == "lts":
urls = ["https://download.pytorch.org/whl/lts/1.8/torch_lts.html"]
else:
urls = [
f"https://download.pytorch.org/whl/"
f"{channel}/{backend}/torch_{channel}.html"
for backend in sorted(computation_backends, key=str)
]
self.search_scope = SearchScope.create(find_links=urls, index_urls=[])
@classmethod
def from_link_collector(
cls,
link_collector: LinkCollector,
computation_backends: Set[cb.ComputationBackend],
channel: str = "stable",
) -> "PytorchLinkCollector":
return new_from_similar(
cls,
link_collector,
("session", "search_scope"),
channel=channel,
computation_backends=computation_backends,
)
class PytorchPackageFinder(PackageFinder):
_candidate_prefs: PytorchCandidatePreferences
_link_collector: PytorchLinkCollector
def __init__(
self,
*args: Any,
computation_backends: Set[cb.ComputationBackend],
channel: str = "stable",
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self._candidate_prefs = PytorchCandidatePreferences.from_candidate_preferences(
self._candidate_prefs, computation_backends=computation_backends
)
self._link_collector = PytorchLinkCollector.from_link_collector(
self._link_collector,
channel=channel,
computation_backends=computation_backends,
)
def make_candidate_evaluator(
self, *args: Any, **kwargs: Any,
) -> PytorchCandidateEvaluator:
candidate_evaluator = super().make_candidate_evaluator(*args, **kwargs)
return PytorchCandidateEvaluator.from_candidate_evaluator(
candidate_evaluator,
computation_backends=self._candidate_prefs.computation_backends,
)
def make_link_evaluator(self, *args: Any, **kwargs: Any) -> PytorchLinkEvaluator:
link_evaluator = super().make_link_evaluator(*args, **kwargs)
return PytorchLinkEvaluator.from_link_evaluator(link_evaluator)
@classmethod
def from_package_finder(
cls,
package_finder: PackageFinder,
computation_backends: Set[cb.ComputationBackend],
channel: str = "stable",
) -> "PytorchPackageFinder":
return new_from_similar(
cls,
package_finder,
(
"link_collector",
"target_python",
"allow_yanked",
"format_control",
"candidate_prefs",
"ignore_requires_python",
),
computation_backends=computation_backends,
channel=channel,
)
class StopAfterPytorchLinksFoundResolver(PatchedResolverBase):
def _resolve_one(
self, requirement_set: RequirementSet, req_to_install: InstallRequirement
) -> List[InstallRequirement]:
self._populate_link(req_to_install)
return []
def resolve(
self, root_reqs: List[InstallRequirement], check_supported_wheels: bool
) -> NoReturn:
requirement_set = super().resolve(root_reqs, check_supported_wheels)
links = [req.link.url for req in requirement_set.all_requirements]
raise PytorchLinksFound(links)
class StopAfterPytorchLinksFoundCommand(PatchedInstallCommand):
def __init__(
self,
computation_backends: Set[cb.ComputationBackend],
channel: str = "stable",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.computation_backends = computation_backends
self.channel = channel
def _build_package_finder(self, *args: Any, **kwargs: Any) -> PytorchPackageFinder:
package_finder = super()._build_package_finder(*args, **kwargs)
return PytorchPackageFinder.from_package_finder(
package_finder,
computation_backends=self.computation_backends,
channel=self.channel,
)
def make_resolver(
self, *args: Any, **kwargs: Any
) -> StopAfterPytorchLinksFoundResolver:
resolver = super().make_resolver(*args, **kwargs)
return StopAfterPytorchLinksFoundResolver.from_resolver(resolver)

light_the_torch/_utils.py Normal file

@@ -0,0 +1,113 @@
import contextlib
import functools
import importlib
import inspect
import itertools
from unittest import mock
class InternalError(RuntimeError):
def __init__(self) -> None:
# TODO: check against pip version
msg = (
"Unexpected internal light-the-torch error. If you ever encounter this "
"message during normal operation, please submit a bug report at "
"https://github.com/pmeier/light-the-torch/issues"
)
super().__init__(msg)
class Input(dict):
def __init__(self, fn, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__fn__ = fn
def __getattr__(self, key):
return self[key]
def __setattr__(self, key, value) -> None:
self[key] = value
def __delattr__(self, key) -> None:
del self[key]
@classmethod
def from_call_args(cls, fn, *args, **kwargs):
params = iter(inspect.signature(fn).parameters.values())
for arg, param in zip(args, params):
kwargs[param.name] = arg
for param in params:
if (
param.name not in kwargs
and param.default is not inspect.Parameter.empty
):
kwargs[param.name] = param.default
return cls(fn, kwargs)
def to_call_args(self):
params = iter(inspect.signature(self.__fn__).parameters.values())
args = []
for param in params:
if param.kind != inspect.Parameter.POSITIONAL_ONLY:
break
args.append(self[param.name])
else:
return (), dict()
args = tuple(args)
kwargs = dict()
sentinel = object()
for param in itertools.chain([param], params):
kwarg = self.get(param.name, sentinel)
if kwarg is not sentinel:
kwargs[param.name] = kwarg
return args, kwargs
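`Input.from_call_args` normalizes a call into a single name-to-value mapping by binding positional arguments to parameter names and filling in unset defaults, so the pre- and postprocessing hooks can read and mutate arguments by name. The binding step on its own (`bind` and `install` are hypothetical names for illustration):

```python
import inspect

# Hypothetical `bind` helper showing the normalization done by
# Input.from_call_args: positional args are matched to parameter names and
# unset defaults are filled in, yielding one flat name -> value mapping.
def bind(fn, *args, **kwargs):
    params = iter(inspect.signature(fn).parameters.values())
    for arg, param in zip(args, params):
        kwargs[param.name] = arg
    for param in params:
        if param.name not in kwargs and param.default is not inspect.Parameter.empty:
            kwargs[param.name] = param.default
    return kwargs

def install(name, upgrade=False, index_url="https://pypi.org/simple"):
    pass
```

The stdlib's `inspect.Signature.bind` plus `apply_defaults` achieves a similar normalization, though it rejects surplus keyword arguments that the hand-rolled loop tolerates.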
@contextlib.contextmanager
def apply_fn_patch(
*parts,
preprocessing=lambda input: input,
context=contextlib.nullcontext,
postprocessing=lambda input, output: output,
):
target = ".".join(parts)
fn = import_fn(target)
@functools.wraps(fn)
def new(*args, **kwargs):
input = Input.from_call_args(fn, *args, **kwargs)
input = preprocessing(input)
with context(input):
args, kwargs = input.to_call_args()
output = fn(*args, **kwargs)
return postprocessing(input, output)
with mock.patch(target, new=new):
yield
def import_fn(target: str):
attrs = []
name = target
while name:
try:
module = importlib.import_module(name)
break
except ImportError:
name, attr = name.rsplit(".", 1)
attrs.append(attr)
else:
raise InternalError
obj = module
for attr in attrs[::-1]:
obj = getattr(obj, attr)
return obj
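`import_fn` walks the dotted path from the right: it imports the longest importable module prefix, then resolves the remaining parts as attributes. A self-contained copy of that strategy with the internal error type swapped for a plain `RuntimeError`:

```python
import importlib

# Self-contained copy of the import_fn strategy above: try to import the
# longest module prefix of the dotted path, then getattr the remainder.
def import_fn(target: str):
    attrs = []
    name = target
    while name:
        try:
            module = importlib.import_module(name)
            break
        except ImportError:
            # peel one attribute off the right and retry with the shorter prefix
            name, attr = name.rsplit(".", 1)
            attrs.append(attr)
    else:
        raise RuntimeError(f"unable to import {target}")
    obj = module
    for attr in reversed(attrs):
        obj = getattr(obj, attr)
    return obj
```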


@@ -1,176 +0,0 @@
import argparse
import sys
from typing import List, Optional, Sequence, Tuple
from .._pip.common import make_pip_install_parser
from .commands import make_command
__all__ = ["main"]
def main(args: Optional[List[str]] = None) -> None:
args = parse_args(args)
cmd = make_command(args)
try:
pip_install_args = args.args
except AttributeError:
pip_install_args = []
cmd.run(pip_install_args)
def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
if args is None:
args = sys.argv[1:]
parser = make_ltt_parser()
return parser.parse_args(args)
class LTTParser(argparse.ArgumentParser):
def parse_known_args(
self,
args: Optional[Sequence[str]] = None,
namespace: Optional[argparse.Namespace] = None,
) -> Tuple[argparse.Namespace, List[str]]:
args, argv = super().parse_known_args(args=args, namespace=namespace)
if not argv:
return args, argv
message = (
f"Unrecognized arguments: {', '.join(argv)}. If they were meant as "
"optional 'pip install' arguments, they have to be passed after a '--' "
"seperator."
)
self.error(message)
@staticmethod
def add_verbose(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--verbose",
action="store_true",
help=(
"print more output to STDOUT. For fine control use -v / --verbose and "
"-q / --quiet of the 'pip install' options"
),
)
@staticmethod
def add_pip_install_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"args",
nargs="*",
help=(
"arguments of 'pip install'. Optional arguments have to be seperated "
"by '--'"
),
)
@staticmethod
def add_common_arguments(parser: argparse.ArgumentParser) -> None:
LTTParser.add_verbose(parser)
LTTParser.add_pip_install_args(parser)
def make_ltt_parser() -> LTTParser:
parser = LTTParser(prog="ltt")
parser.add_argument(
"-V",
"--version",
action="store_true",
help="show light-the-torch version and path and exit",
)
subparsers = parser.add_subparsers(dest="subcommand", title="subcommands")
add_ltt_install_parser(subparsers)
add_ltt_extract_parser(subparsers)
add_ltt_find_parser(subparsers)
return parser
SubParsers = argparse._SubParsersAction
def add_ltt_install_parser(subparsers: SubParsers) -> None:
parser = subparsers.add_parser(
"install",
description=(
"Install PyTorch distributions from the stable releases. The computation "
"backend is auto-detected from the available hardware preferring CUDA "
"over CPU."
),
)
parser.add_argument(
"--force-cpu",
action="store_true",
help="disable computation backend auto-detection and use CPU instead",
)
parser.add_argument(
"--pytorch-only",
action="store_true",
help="install only PyTorch distributions",
)
parser.add_argument(
"--channel",
type=str,
default="stable",
help=(
"Channel of the PyTorch wheels. "
"Can be one of 'stable' (default), 'lts', 'test', or 'nightly'"
),
)
parser.add_argument(
"--install-cmd",
type=str,
default="python -m pip install {packages}",
help=(
"installation command for the PyTorch distributions and additional "
"packages. Defaults to 'python -m pip install {packages}'"
),
)
LTTParser.add_common_arguments(parser)
def add_ltt_extract_parser(subparsers: SubParsers) -> None:
parser = subparsers.add_parser(
"extract", description="Extract required PyTorch distributions"
)
LTTParser.add_common_arguments(parser)
def add_ltt_find_parser(subparsers: SubParsers) -> None:
parser = subparsers.add_parser(
"find", description="Find wheel links for the required PyTorch distributions"
)
parser.add_argument(
"--computation-backend",
type=str,
help=(
"Only use wheels compatible with COMPUTATION_BACKEND, for example 'cu102' "
"or 'cpu'. Defaults to the computation backend of the running system, "
"preferring CUDA over CPU."
),
)
parser.add_argument(
"--channel",
type=str,
default="stable",
help=(
"Channel of the PyTorch wheels. "
"Can be one of 'stable' (default), 'test', or 'nightly'"
),
)
add_pip_install_arguments(parser, "platform", "python_version")
LTTParser.add_common_arguments(parser)
def add_pip_install_arguments(parser: argparse.ArgumentParser, *dests: str) -> None:
pip_install_parser = make_pip_install_parser()
option_group = pip_install_parser.option_groups[0]
for dest in dests:
options = [option for option in option_group.option_list if option.dest == dest]
assert len(options) == 1
option = options[0]
parser.add_argument(*option._short_opts, *option._long_opts, help=option.help)


@@ -1,130 +0,0 @@
import argparse
import subprocess
import sys
import warnings
from abc import ABC, abstractmethod
from os import path
from typing import Dict, List, NoReturn, Optional, Type
import light_the_torch as ltt
from .._pip.common import make_pip_install_parser
from ..computation_backend import CPUBackend
__all__ = ["make_command"]
class Command(ABC):
@abstractmethod
def __init__(self, args: argparse.Namespace) -> None:
pass
def run(self, pip_install_args: List[str]) -> None:
self._run(pip_install_args)
self.exit()
@abstractmethod
def _run(self, pip_install_args: List[str]) -> None:
pass
def exit(self, code: Optional[int] = None, error: bool = False) -> NoReturn:
if code is None:
code = 1 if error else 0
sys.exit(code)
class GlobalCommand(Command):
def __init__(self, args: argparse.Namespace) -> None:
self.version = args.version
def _run(self, pip_install_args: List[str]) -> None:
if self.version:
root = path.abspath(path.join(path.dirname(__file__), ".."))
print(f"{ltt.__name__}=={ltt.__version__} from {root}")
class ExtractCommand(Command):
def __init__(self, args: argparse.Namespace) -> None:
self.verbose = args.verbose
def _run(self, pip_install_args: List[str]) -> None:
dists = ltt.extract_dists(pip_install_args, verbose=self.verbose)
print("\n".join(dists))
class FindCommand(Command):
def __init__(self, args: argparse.Namespace) -> None:
# TODO split by comma
self.computation_backends = args.computation_backend
self.channel = args.channel
self.platform = args.platform
self.python_version = args.python_version
self.verbose = args.verbose
def _run(self, pip_install_args: List[str]) -> None:
links = ltt.find_links(
pip_install_args,
computation_backends=self.computation_backends,
channel=self.channel,
platform=self.platform,
python_version=self.python_version,
verbose=self.verbose,
)
print("\n".join(links))
class InstallCommand(Command):
def __init__(self, args: argparse.Namespace) -> None:
self.force_cpu = args.force_cpu
self.pytorch_only = args.pytorch_only
self.channel = args.channel
install_cmd = args.install_cmd
if "{packages}" not in install_cmd:
self.exit(error=True)
self.install_cmd = install_cmd
self.verbose = args.verbose
def _run(self, pip_install_args: List[str]) -> None:
links = ltt.find_links(
pip_install_args,
computation_backends={CPUBackend()} if self.force_cpu else None,
channel=self.channel,
verbose=self.verbose,
)
if links:
cmd = self.install_cmd.format(packages=" ".join(links))
subprocess.check_call(cmd, shell=True)
else:
warnings.warn(
f"Didn't find any PyTorch distribution in "
f"'{' '.join(pip_install_args)}'",
RuntimeWarning,
)
if self.pytorch_only:
self.exit()
cmd = self.install_cmd.format(packages=self.collect_packages(pip_install_args))
subprocess.check_call(cmd, shell=True)
def collect_packages(self, pip_install_args: List[str]) -> str:
parser = make_pip_install_parser()
options, args = parser.parse_args(pip_install_args)
editables = [f"-e {e}" for e in options.editables]
requirements = [f"-r {r}" for r in options.requirements]
return " ".join(editables + requirements + args)
COMMAND_CLASSES: Dict[Optional[str], Type[Command]] = {
    None: GlobalCommand,
    "extract": ExtractCommand,
    "find": FindCommand,
    "install": InstallCommand,
}
def make_command(args: argparse.Namespace) -> Command:
    cls = COMMAND_CLASSES[args.subcommand]
    return cls(args)
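The dictionary-based dispatch above, with `None` keying the behaviour when no subcommand is given, is a common CLI pattern. A minimal self-contained sketch (class and field names here are illustrative, not from the project):

```python
import argparse
from typing import Dict, Optional, Type


class Command:
    def __init__(self, args: argparse.Namespace) -> None:
        self.args = args

    def run(self) -> str:
        raise NotImplementedError


class Greet(Command):
    def run(self) -> str:
        return f"hello {self.args.name}"


# None keys the no-subcommand case, mirroring the mapping above
COMMANDS: Dict[Optional[str], Type[Command]] = {"greet": Greet}


def make_command(args: argparse.Namespace) -> Command:
    # look up the command class by subcommand name and instantiate it
    return COMMANDS[args.subcommand](args)


args = argparse.Namespace(subcommand="greet", name="world")
assert make_command(args).run() == "hello world"
```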


@@ -1,72 +0,0 @@
from typing import Any, Optional
__all__ = ["find_compatible_torch_version"]
class Version:
@classmethod
def from_str(cls, version: str) -> "Version":
parts = version.split(".")
major = int(parts[0])
minor = int(parts[1]) if len(parts) > 1 else None
patch = int(parts[2]) if len(parts) > 2 else None
return cls(major, minor, patch)
def __init__(self, major: int, minor: Optional[int], patch: Optional[int]) -> None:
self.major = major
self.minor = minor
self.patch = patch
self.parts = (major, minor, patch)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Version):
return False
return all(
[
self_part == other_part
for self_part, other_part in zip(self.parts, other.parts)
if self_part is not None and other_part is not None
]
)
def __hash__(self) -> int:
return hash(self.parts)
def __repr__(self) -> str:
return ".".join([str(part) for part in self.parts if part is not None])
COMPATIBILITY = {
"torchvision": {
Version(0, 9, 1): Version(1, 8, 1),
Version(0, 9, 0): Version(1, 8, 0),
Version(0, 8, 0): Version(1, 7, 0),
Version(0, 7, 0): Version(1, 6, 0),
Version(0, 6, 1): Version(1, 5, 1),
Version(0, 6, 0): Version(1, 5, 0),
Version(0, 5, 0): Version(1, 4, 0),
Version(0, 4, 2): Version(1, 3, 1),
Version(0, 4, 1): Version(1, 3, 0),
Version(0, 4, 0): Version(1, 2, 0),
Version(0, 3, 0): Version(1, 1, 0),
Version(0, 2, 2): Version(1, 0, 1),
}
}
def find_compatible_torch_version(dist: str, version: str) -> str:
version = Version.from_str(version)
dist_compatibility = COMPATIBILITY[dist]
candidates = [x for x in dist_compatibility.keys() if x == version]
if not candidates:
raise RuntimeError(
f"No compatible torch version was found for {dist}=={version}"
)
if len(candidates) != 1:
raise RuntimeError(
f"Multiple compatible torch versions were found for {dist}=={version}:\n"
f"{', '.join([str(candidate) for candidate in candidates])}\n"
)
return str(dist_compatibility[candidates[0]])
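`find_compatible_torch_version` relies on `Version`'s fuzzy `__eq__`: only the components present on both sides are compared, so a partial version like `0.8` matches the full table key `0.8.0`. A minimal self-contained sketch of that matching rule (helper names are illustrative, not from the project):

```python
def parts(version: str):
    # split a dotted version string into integer components
    return tuple(int(p) for p in version.split("."))


def fuzzy_eq(a: str, b: str) -> bool:
    # compare only the components both versions specify,
    # mirroring Version.__eq__ above (zip stops at the shorter tuple)
    return all(x == y for x, y in zip(parts(a), parts(b)))


assert fuzzy_eq("0.8", "0.8.0")       # partial version matches full key
assert not fuzzy_eq("0.8", "0.9.0")   # mismatching minor component
```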


@@ -1,34 +0,0 @@
[mypy]
; https://mypy.readthedocs.io/en/stable/config_file.html
; import discovery
files = light_the_torch
; untyped definitions and calls
disallow_untyped_defs = True
; None and Optional handling
no_implicit_optional = True
; warnings
warn_redundant_casts = True
warn_unused_ignores = True
warn_return_any = True
warn_unreachable = True
; miscellaneous strictness flags
allow_redefinition = True
; configuring error messages
show_error_context = True
show_error_codes = True
pretty = True
; miscellaneous
warn_unused_configs = True
[mypy-light_the_torch]
warn_unused_ignores = False
[mypy-pip.*]
ignore_missing_imports = True


@@ -1,10 +1,16 @@
[pytest]
;See link below for available options
;https://docs.pytest.org/en/latest/reference.html#ini-options-ref
;See https://docs.pytest.org/en/latest/reference.html#ini-options-ref for available
; options
markers =
large_download
slow
flaky
testpaths = tests/
addopts = -ra
addopts =
# show summary of all tests that did not pass
-ra
# Make tracebacks shorter
--tb=short
# enable all warnings
-Wd
# coverage
--cov=light_the_torch
--cov-config=.coveragerc
xfail_strict = True
testpaths = tests


@@ -1,2 +1,12 @@
tox >= 3.2
doit
# format & lint
pre-commit
flake8 ==4.0.1
# test
pytest
pytest-mock
pytest-cov
# publish
build
twine
check-wheel-contents


@@ -1,13 +1,13 @@
[metadata]
name = light_the_torch
platforms = any
description = Install PyTorch distributions computation backend auto-detection
long_description = file: README.rst
long_description_content_type = text/x-rst
description = Install PyTorch distributions with computation backend auto-detection
long_description = file: README.md
long_description_content_type = text/markdown; charset=UTF-8; variant=GFM
keywords = pytorch, cuda, pip, install
url = https://github.com/pmeier/light-the-torch
author = Philip Meier
author-email = github.pmeier@posteo.de
author_email = github.pmeier@posteo.de
license = BSD-3-Clause
classifiers =
Development Status :: 3 - Alpha
@@ -15,12 +15,12 @@ classifiers =
Environment :: GPU :: NVIDIA CUDA
Intended Audience :: Developers
License :: OSI Approved :: BSD License
Programming Language :: Python :: 3.6
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Topic :: System :: Installation/Setup
Topic :: Utilities
Typing :: Typed
project_urls =
Source = https://github.com/pmeier/light-the-torch
Tracker = https://github.com/pmeier/light-the-torch/issues
@@ -28,15 +28,14 @@ project_urls =
[options]
packages = find:
include_package_data = True
python_requires = >=3.6
python_requires = >=3.7
install_requires =
pip >=20.1.*, <20.3.*
pip ==22.0.*, != 22.0, != 22.0.1, != 22.0.2
[options.packages.find]
exclude =
tests
tests.*
tests*
[options.entry_points]
console_scripts =
ltt=light_the_torch.cli:main
ltt=light_the_torch._cli:main


@@ -1,44 +0,0 @@
import io
import sys
import pytest
@pytest.fixture
def patch_argv(mocker):
def patch_argv_(*args):
return mocker.patch.object(sys, "argv", ["ltt", *args])
return patch_argv_
@pytest.fixture
def patch_extract_dists(mocker):
def patch_extract_dists_(return_value=None):
if return_value is None:
return_value = []
return mocker.patch(
"light_the_torch.cli.commands.ltt.extract_dists", return_value=return_value,
)
return patch_extract_dists_
@pytest.fixture
def patch_find_links(mocker):
def patch_find_links_(return_value=None):
if return_value is None:
return_value = []
return mocker.patch(
"light_the_torch.cli.commands.ltt.find_links", return_value=return_value,
)
return patch_find_links_
@pytest.fixture
def patch_stdout(mocker):
def patch_stdout_():
return mocker.patch.object(sys, "stdout", io.StringIO())
return patch_stdout_


@@ -1,52 +0,0 @@
import pytest
from light_the_torch import cli
from .utils import exits
@pytest.fixture
def patch_extract_argv(patch_argv):
def patch_extract_argv_(*args):
return patch_argv("extract", *args)
return patch_extract_argv_
def test_ltt_extract(subtests, patch_extract_argv, patch_extract_dists, patch_stdout):
pip_install_args = ["foo"]
dists = ["bar", "baz"]
patch_extract_argv(*pip_install_args)
extract_dists = patch_extract_dists(dists)
stdout = patch_stdout()
with exits():
cli.main()
with subtests.test("extract_dists"):
args, _ = extract_dists.call_args
assert args[0] == pip_install_args
with subtests.test("stdout"):
output = stdout.getvalue().strip()
assert output == "\n".join(dists)
def test_ltt_extract_verbose(patch_extract_argv, patch_extract_dists):
patch_extract_argv("--verbose")
extract_dists = patch_extract_dists([])
with exits():
cli.main()
_, kwargs = extract_dists.call_args
assert "verbose" in kwargs
assert kwargs["verbose"]
def test_extract_unrecognized_argument(patch_extract_argv):
patch_extract_argv("--unrecognized-argument")
with exits(error=True):
cli.main()


@@ -1,52 +0,0 @@
import pytest
from light_the_torch import cli
from .utils import exits
@pytest.fixture
def patch_find_argv(patch_argv):
def patch_find_argv_(*args):
return patch_argv("find", *args)
return patch_find_argv_
def test_ltt_find(subtests, patch_find_argv, patch_find_links, patch_stdout):
pip_install_args = ["foo"]
links = ["bar", "baz"]
patch_find_argv(*pip_install_args)
find_links = patch_find_links(links)
stdout = patch_stdout()
with exits():
cli.main()
with subtests.test("find_links"):
args, _ = find_links.call_args
assert args[0] == pip_install_args
with subtests.test("stdout"):
output = stdout.getvalue().strip()
assert output == "\n".join(links)
def test_ltt_find_verbose(patch_find_argv, patch_find_links):
patch_find_argv("--verbose")
find_links = patch_find_links([])
with exits():
cli.main()
_, kwargs = find_links.call_args
assert "verbose" in kwargs
assert kwargs["verbose"]
def test_find_unrecognized_argument(patch_find_argv):
patch_find_argv("--unrecognized-argument")
with exits(error=True):
cli.main()


@@ -1,46 +0,0 @@
import subprocess
import light_the_torch as ltt
from light_the_torch import cli
from .utils import exits
def test_ltt_main_smoke(subtests):
for arg in ("-h", "-V"):
cmd = f"python -m light_the_torch {arg}"
with subtests.test(cmd=cmd):
subprocess.check_call(cmd, shell=True)
def test_ltt_help_smoke(subtests, patch_argv, patch_stdout):
for arg in ("-h", "--help"):
with subtests.test(arg=arg):
patch_argv(arg)
stdout = patch_stdout()
with exits():
cli.main()
assert stdout.getvalue().strip()
def test_ltt_version(subtests, patch_argv, patch_stdout):
for arg in ("-V", "--version"):
with subtests.test(arg=arg):
patch_argv(arg)
stdout = patch_stdout()
with exits():
cli.main()
output = stdout.getvalue().strip()
assert output.startswith(f"{ltt.__name__}=={ltt.__version__}")
def test_ltt_unknown_subcommand(patch_argv):
subcommand = "unknown"
patch_argv(subcommand)
with exits(error=True):
cli.main()


@@ -1,215 +0,0 @@
import pytest
from light_the_torch import cli
from light_the_torch.computation_backend import CPUBackend
from .utils import exits
@pytest.fixture
def patch_install_argv(patch_argv):
def patch_install_argv_(*args):
return patch_argv("install", *args)
return patch_install_argv_
@pytest.fixture
def patch_subprocess_call(mocker):
def patch_subprocess_call_():
return mocker.patch("light_the_torch.cli.commands.subprocess.check_call")
return patch_subprocess_call_
@pytest.fixture
def patch_collect_packages(mocker):
def patch_collect_packages_(return_value=None):
if return_value is None:
return_value = []
return mocker.patch(
"light_the_torch.cli.commands.InstallCommand.collect_packages",
return_value=return_value,
)
return patch_collect_packages_
def test_ltt_install(
subtests, patch_install_argv, patch_find_links, patch_subprocess_call
):
install_cmd = "python -m pip install {packages}"
pip_install_args = ["foo", "bar"]
links = ["https://foo.org"]
patch_install_argv(*pip_install_args)
patch_find_links(links)
subprocess_call = patch_subprocess_call()
with exits():
cli.main()
assert subprocess_call.call_count == 2
with subtests.test("install PyTorch"):
call_args = subprocess_call.call_args_list[0]
args, _ = call_args
assert args[0] == install_cmd.format(packages=" ".join(links))
with subtests.test("install remainder"):
call_args = subprocess_call.call_args_list[1]
args, _ = call_args
assert args[0] == install_cmd.format(packages=" ".join(pip_install_args))
def test_ltt_install_force_cpu(
patch_install_argv, patch_find_links, patch_subprocess_call, patch_collect_packages,
):
patch_install_argv("--force-cpu")
find_links = patch_find_links()
patch_subprocess_call()
patch_collect_packages()
with exits():
cli.main()
_, kwargs = find_links.call_args
assert "computation_backends" in kwargs
assert set(kwargs["computation_backends"]) == {CPUBackend()}
def test_ltt_install_pytorch_only(
patch_install_argv, patch_find_links, patch_subprocess_call, patch_collect_packages,
):
patch_install_argv("--pytorch-only")
patch_find_links()
patch_subprocess_call()
collect_packages = patch_collect_packages()
with exits():
cli.main()
collect_packages.assert_not_called()
def test_ltt_install_channel(
patch_install_argv, patch_find_links, patch_subprocess_call, patch_collect_packages,
):
channel = "channel"
patch_install_argv(f"--channel={channel}")
find_links = patch_find_links()
patch_subprocess_call()
patch_collect_packages()
with exits():
cli.main()
_, kwargs = find_links.call_args
assert "channel" in kwargs
assert kwargs["channel"] == channel
def test_ltt_install_install_cmd(
patch_install_argv, patch_find_links, patch_subprocess_call,
):
install_cmd = "custom install {packages}"
packages = ["foo", "bar"]
patch_install_argv("--pytorch-only", "--install-cmd", install_cmd)
patch_find_links(packages)
subprocess_call = patch_subprocess_call()
with exits():
cli.main()
args, _ = subprocess_call.call_args
assert args[0] == install_cmd.format(packages=" ".join(packages))
def test_ltt_install_install_cmd_no_subs(patch_install_argv):
patch_install_argv("--install-cmd", "no proper packages substitution")
with exits(error=True):
cli.main()
def test_ltt_install_editables(
patch_install_argv, patch_find_links, patch_subprocess_call,
):
install_cmd = "custom install {packages}"
editables = [".", "foo"]
args = ["bar", "baz"]
cmd = install_cmd.format(packages=" ".join([f"-e {e}" for e in editables] + args))
patch_install_argv(
"--install-cmd",
install_cmd,
"--",
"-e",
editables[0],
"--editable",
editables[1],
*args,
)
patch_find_links()
subprocess_call = patch_subprocess_call()
with exits():
cli.main()
args, _ = subprocess_call.call_args
assert args[0] == cmd
def test_ltt_install_requirements(
patch_install_argv, patch_find_links, patch_subprocess_call,
):
install_cmd = "custom install {packages}"
requirements = ["requirements.txt", "requirements-dev.txt"]
args = ["foo", "bar"]
cmd = install_cmd.format(
packages=" ".join([f"-r {r}" for r in requirements] + args)
)
patch_install_argv(
"--install-cmd",
install_cmd,
"--",
"-r",
requirements[0],
"--requirement",
requirements[1],
*args,
)
patch_find_links()
subprocess_call = patch_subprocess_call()
with exits():
cli.main()
args, _ = subprocess_call.call_args
assert args[0] == cmd
def test_ltt_install_verbose(
patch_install_argv, patch_find_links, patch_subprocess_call, patch_collect_packages,
):
patch_install_argv("--verbose")
find_links = patch_find_links()
patch_subprocess_call()
patch_collect_packages()
with exits():
cli.main()
_, kwargs = find_links.call_args
assert "verbose" in kwargs
assert kwargs["verbose"]
def test_install_unrecognized_argument(patch_install_argv):
patch_install_argv("--unrecognized-argument")
with exits(error=True):
cli.main()


@@ -1,21 +0,0 @@
import contextlib
import pytest
__all__ = ["exits"]
@contextlib.contextmanager
def exits(code=None, error=False):
with pytest.raises(SystemExit) as info:
yield
ret = info.value.code
if code is not None:
assert ret == code
if error:
assert ret >= 1
else:
assert ret is None or ret == 0
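The `exits` helper above wraps `pytest.raises(SystemExit)` and then asserts on the exit code. The same idea can be sketched without pytest, which also makes the code-vs-error semantics explicit (this is an illustrative reimplementation, not the project's helper):

```python
import contextlib
import sys


@contextlib.contextmanager
def exits(code=None, error=False):
    # capture SystemExit raised inside the block and check its code
    try:
        yield
    except SystemExit as exc:
        ret = exc.code
    else:
        raise AssertionError("block did not raise SystemExit")
    if code is not None:
        assert ret == code
    elif error:
        assert isinstance(ret, int) and ret >= 1
    else:
        # sys.exit() with no argument yields a None code
        assert ret is None or ret == 0


with exits(error=True):
    sys.exit(2)

with exits():
    sys.exit(0)
```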


@@ -1,77 +0,0 @@
import pytest
class MarkConfig:
def __init__(
self,
keyword,
run_by_default,
addoption=True,
option=None,
help=None,
condition_for_skip=None,
reason=None,
):
self.addoption = addoption
if option is None:
option = (
f"--{'skip' if run_by_default else 'run'}-{keyword.replace('_', '-')}"
)
self.option = option
if help is None:
help = (
f"{'Skip' if run_by_default else 'Run'} tests decorated with @{keyword}"
)
self.help = help
if condition_for_skip is None:
def condition_for_skip(config, item):
has_keyword = keyword in item.keywords
if run_by_default:
return has_keyword and config.getoption(option)
else:
return has_keyword and not config.getoption(option)
self.condition_for_skip = condition_for_skip
if reason is None:
reason = (
f"Test is {keyword} and {option} was "
f"{'' if run_by_default else 'not '}given."
)
self.marker = pytest.mark.skip(reason=reason)
MARK_CONFIGS = (
MarkConfig(
keyword="large_download",
run_by_default=True,
reason=(
"Test possibly includes a large download and --skip-large-download was "
"given."
),
),
MarkConfig(keyword="slow", run_by_default=True),
MarkConfig(keyword="flaky", run_by_default=False),
)
def pytest_addoption(parser):
for mark_config in MARK_CONFIGS:
if mark_config.addoption:
parser.addoption(
mark_config.option,
action="store_true",
default=False,
help=mark_config.help,
)
def pytest_collection_modifyitems(config, items):
for item in items:
for mark_config in MARK_CONFIGS:
if mark_config.condition_for_skip(config, item):
item.add_marker(mark_config.marker)

tests/test_cli.py (new file)

@@ -0,0 +1,115 @@
import contextlib
import io
import shlex
import subprocess
import sys
import light_the_torch as ltt
import pytest
from light_the_torch._cli import main
@pytest.mark.parametrize("cmd", ["ltt", "python -m light_the_torch"])
def test_entry_point_smoke(cmd):
subprocess.run(shlex.split(cmd), shell=False)
@contextlib.contextmanager
def exits(*, should_succeed=True, expected_code=None, check_err=None, check_out=None):
def parse_checker(checker):
if checker is None or callable(checker):
return checker
if isinstance(checker, str):
checker = (checker,)
def check_fn(text):
for phrase in checker:
assert phrase in text
return check_fn
check_err = parse_checker(check_err)
check_out = parse_checker(check_out)
with pytest.raises(SystemExit) as info:
with contextlib.redirect_stderr(io.StringIO()) as raw_err:
with contextlib.redirect_stdout(io.StringIO()) as raw_out:
yield
returned_code = info.value.code or 0
succeeded = returned_code == 0
err = raw_err.getvalue().strip()
out = raw_out.getvalue().strip()
if expected_code is not None:
if returned_code == expected_code:
return
raise AssertionError(
f"Returned and expected return code mismatch: "
f"{returned_code} != {expected_code}."
)
if should_succeed:
if succeeded:
if check_out:
check_out(out)
return
raise AssertionError(
f"Program should have succeeded, but returned code {returned_code} "
f"and printed the following to STDERR: '{err}'."
)
else:
if not succeeded:
if check_err:
check_err(err)
return
raise AssertionError("Program shouldn't have succeeded, but did.")
@pytest.fixture
def set_argv(mocker):
def patch(*options):
return mocker.patch.object(sys, "argv", ["ltt", *options])
return patch
@pytest.mark.parametrize("option", ["-h", "--help"])
def test_help_smoke(set_argv, option):
set_argv(option)
def check_out(out):
assert out
with exits(check_out=check_out):
main()
@pytest.mark.parametrize("option", ["-V", "--version"])
def test_version(set_argv, option):
set_argv(option)
with exits(check_out=f"ltt {ltt.__version__} from {ltt.__path__[0]}"):
main()
@pytest.mark.parametrize(
"option",
[
"--pytorch-computation-backend",
"--cpuonly",
"--pytorch-channel",
],
)
def test_ltt_options_smoke(set_argv, option):
set_argv("install", "--help")
with exits(check_out=option):
main()


@@ -1,8 +1,40 @@
import subprocess
from types import SimpleNamespace
import pytest
from light_the_torch import computation_backend as cb
from light_the_torch import _cb as cb
try:
subprocess.check_call(
"nvidia-smi",
shell=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
NVIDIA_DRIVER_AVAILABLE = True
except subprocess.CalledProcessError:
NVIDIA_DRIVER_AVAILABLE = False
skip_if_nvidia_driver_unavailable = pytest.mark.skipif(
not NVIDIA_DRIVER_AVAILABLE, reason="Requires NVIDIA driver."
)
class GenericComputationBackend(cb.ComputationBackend):
@property
def local_specifier(self):
return "generic"
def __lt__(self, other):
return NotImplemented
@pytest.fixture
def generic_backend():
return GenericComputationBackend()
class TestComputationBackend:
@@ -42,7 +74,7 @@ class TestComputationBackend:
@pytest.mark.parametrize("string", (("unknown", "cudnn")))
def test_from_str_unknown(self, string):
with pytest.raises(cb.ParseError):
with pytest.raises(ValueError, match=string):
cb.ComputationBackend.from_str(string)
@@ -70,26 +102,12 @@ class TestOrdering:
assert cb.CUDABackend(2, 1) < cb.CUDABackend(10, 0)
try:
subprocess.check_call(
"nvidia-smi", shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
)
NVIDIA_DRIVER_AVAILABLE = True
except subprocess.CalledProcessError:
NVIDIA_DRIVER_AVAILABLE = False
skip_if_nvidia_driver_unavailable = pytest.mark.skipif(
not NVIDIA_DRIVER_AVAILABLE, reason="Requires NVIDIA driver."
)
@pytest.fixture
def patch_nvidia_driver_version(mocker):
def factory(version):
return mocker.patch(
"light_the_torch.computation_backend.subprocess.check_output",
return_value=f"driver_version\n{version}".encode("utf-8"),
"light_the_torch._cb.subprocess.run",
return_value=SimpleNamespace(stdout=f"driver_version\n{version}"),
)
return factory
@@ -129,7 +147,9 @@ def cuda_backends_params():
pytest.param(
system,
str(driver_versions[idx]),
set(cuda_backends[: idx + 1],),
set(
cuda_backends[: idx + 1],
),
id=f"{system.lower()}-normal",
)
)
@@ -142,7 +162,7 @@ def cuda_backends_params():
class TestDetectCompatibleComputationBackends:
def test_no_nvidia_driver(self, mocker):
mocker.patch(
"light_the_torch.computation_backend.subprocess.check_output",
"light_the_torch._cb.subprocess.run",
side_effect=subprocess.CalledProcessError(1, ""),
)
@@ -157,9 +177,7 @@ class TestDetectCompatibleComputationBackends:
nvidia_driver_version,
compatible_cuda_backends,
):
mocker.patch(
"light_the_torch.computation_backend.platform.system", return_value=system
)
mocker.patch("light_the_torch._cb.platform.system", return_value=system)
patch_nvidia_driver_version(nvidia_driver_version)
backends = cb.detect_compatible_computation_backends()

tests/test_smoke.py (new file)

@@ -0,0 +1,121 @@
import builtins
import importlib
import os
import pathlib
import re
import sys
import pytest
PACKAGE_NAME = "light_the_torch"
PROJECT_ROOT = (pathlib.Path(__file__).parent / "..").resolve()
PACKAGE_ROOT = PROJECT_ROOT / PACKAGE_NAME
def collect_modules():
def is_private(path):
return pathlib.Path(path).name.startswith("_")
def path_to_module(path):
return str(pathlib.Path(path).with_suffix("")).replace(os.sep, ".")
modules = []
for root, dirs, files in os.walk(PACKAGE_ROOT):
if is_private(root) or "__init__.py" not in files:
del dirs[:]
continue
path = pathlib.Path(root).relative_to(PROJECT_ROOT)
modules.append(path_to_module(path))
for file in files:
if is_private(file) or not file.endswith(".py"):
continue
modules.append(path_to_module(path / file))
return modules
@pytest.mark.parametrize("module", collect_modules())
def test_importability(module):
importlib.import_module(module)
def import_package_under_test():
try:
return importlib.import_module(PACKAGE_NAME)
except Exception as error:
raise RuntimeError(
f"The package '{PACKAGE_NAME}' could not be imported. "
f"Check the results of tests/test_smoke.py::test_importability for details."
) from error
def test_version_installed():
def is_canonical(version):
# Copied from
# https://www.python.org/dev/peps/pep-0440/#appendix-b-parsing-version-strings-with-regular-expressions
return (
re.match(
r"^([1-9][0-9]*!)?(0|[1-9][0-9]*)(\.(0|[1-9][0-9]*))*((a|b|rc)(0|[1-9][0-9]*))?(\.post(0|[1-9][0-9]*))?(\.dev(0|[1-9][0-9]*))?$",
version,
)
is not None
)
def is_dev(version):
match = re.search(r"\+g[\da-f]{7}([.]\d{14})?", version)
if match is not None:
return is_canonical(version[: match.span()[0]])
else:
return False
put = import_package_under_test()
assert is_canonical(put.__version__) or is_dev(put.__version__)
def patch_imports(
mocker,
*names,
retain_condition=None,
import_error_condition=None,
):
if retain_condition is None:
def retain_condition(name):
return not any(name.startswith(name_) for name_ in names)
if import_error_condition is None:
def import_error_condition(name, globals, locals, fromlist, level):
direct = name in names
indirect = fromlist is not None and any(
from_ in names for from_ in fromlist
)
return direct or indirect
__import__ = builtins.__import__
def patched_import(name, globals, locals, fromlist, level):
if import_error_condition(name, globals, locals, fromlist, level):
raise ImportError()
return __import__(name, globals, locals, fromlist, level)
mocker.patch.object(builtins, "__import__", new=patched_import)
values = {
name: module for name, module in sys.modules.items() if retain_condition(name)
}
mocker.patch.dict(sys.modules, clear=True, values=values)
def test_version_not_installed(mocker):
def import_error_condition(name, globals, locals, fromlist, level):
return name == "_version" and fromlist == ("version",)
patch_imports(mocker, PACKAGE_NAME, import_error_condition=import_error_condition)
put = import_package_under_test()
assert put.__version__ == "UNKNOWN"
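`patch_imports` above simulates missing dependencies by swapping `builtins.__import__` and clearing `sys.modules`. A minimal pytest-free sketch of the same trick, using a hypothetical module name (`missing_dep` does not exist anywhere; the fallback value mirrors the `"UNKNOWN"` version fallback tested above):

```python
import builtins

real_import = builtins.__import__


def flaky_import(name, *args, **kwargs):
    # raise for the hypothetical module, delegate everything else
    if name == "missing_dep":
        raise ImportError(name)
    return real_import(name, *args, **kwargs)


builtins.__import__ = flaky_import
try:
    try:
        import missing_dep  # noqa: F401
    except ImportError:
        fallback = "UNKNOWN"
finally:
    # always restore the real import machinery
    builtins.__import__ = real_import

assert fallback == "UNKNOWN"
```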


@@ -1,79 +0,0 @@
import itertools
import optparse
import pytest
from light_the_torch._pip import common
def test_get_verbosity(subtests):
verboses = tuple(range(4))
quiets = tuple(range(4))
for verbose, quiet in itertools.product(verboses, quiets):
with subtests.test(verbose=verbose, quiet=quiet):
options = optparse.Values({"verbose": verbose, "quiet": quiet})
verbosity = verbose - quiet
assert common.get_verbosity(options, verbose=True) == verbosity
assert common.get_verbosity(options, verbose=False) == -1
def test_get_public_or_private_attr_public_and_private():
class ObjWithPublicAndPrivateAttribute:
attr = "public"
_attr = "private"
obj = ObjWithPublicAndPrivateAttribute()
assert common.get_public_or_private_attr(obj, "attr") == "public"
def test_get_public_or_private_attr_public_only():
class ObjWithPublicAttribute:
attr = "public"
obj = ObjWithPublicAttribute()
assert common.get_public_or_private_attr(obj, "attr") == "public"
def test_get_public_or_private_attr_private_only():
class ObjWithPrivateAttribute:
_attr = "private"
obj = ObjWithPrivateAttribute()
assert common.get_public_or_private_attr(obj, "attr") == "private"
def test_get_public_or_private_attr_no_attribute():
class ObjWithoutAttribute:
pass
obj = ObjWithoutAttribute()
with pytest.raises(AttributeError):
common.get_public_or_private_attr(obj, "attr")
def test_new_from_similar():
class Object:
def __init__(self, foo, bar="bar"):
self.foo = foo
self._bar = bar
def __eq__(self, other):
foo = self.foo == other.foo
bar = self._bar == other._bar
return foo and bar
class PatchedObj(Object):
def __init__(self, foo, bar="patched_default_bar", baz=None):
super().__init__(foo, bar=bar)
self._baz = baz
obj = Object("foo")
new_obj = common.new_from_similar(PatchedObj, obj, ("foo", "bar"), baz="baz")
assert new_obj is not obj
assert isinstance(new_obj, PatchedObj)
assert new_obj == obj
assert common.get_public_or_private_attr(new_obj, "baz") == "baz"


@@ -1,59 +0,0 @@
import pytest
import light_the_torch as ltt
from light_the_torch._pip import extract
from light_the_torch._pip.common import InternalLTTError
def test_StopAfterPytorchDistsFoundResolver_no_torch(mocker):
mocker.patch(
"light_the_torch._pip.extract.PatchedResolverBase.__init__", return_value=None
)
resolver = extract.StopAfterPytorchDistsFoundResolver()
resolver._required_pytorch_dists = ["torchaudio", "torchtext", "torchvision"]
assert "torch" in resolver.required_pytorch_dists
def test_StopAfterPytorchDistsFoundResolver_torch_compatibility(mocker):
mocker.patch(
"light_the_torch._pip.extract.PatchedResolverBase.__init__", return_value=None
)
resolver = extract.StopAfterPytorchDistsFoundResolver()
resolver._required_pytorch_dists = ["torchvision==0.7"]
assert "torch==1.6.0" in resolver.required_pytorch_dists
def test_extract_pytorch_internal_error(mocker):
mocker.patch("light_the_torch._pip.extract.run")
with pytest.raises(InternalLTTError):
ltt.extract_dists(["foo"])
@pytest.mark.slow
def test_extract_dists_ltt():
assert ltt.extract_dists(["light-the-torch"]) == []
@pytest.mark.slow
def test_extract_dists_pystiche(subtests):
pystiche = "git+https://github.com/pmeier/pystiche@v{}"
reqs_and_dists = (
(pystiche.format("0.4.0"), {"torch>=1.4.0", "torchvision>=0.5.0"}),
(pystiche.format("0.5.0"), {"torch>=1.5.0", "torchvision>=0.6.0"}),
)
for req, dists in reqs_and_dists:
with subtests.test(req=req):
assert set(ltt.extract_dists([req])) == dists
@pytest.mark.slow
def test_extract_dists_kornia(subtests):
kornia = "kornia=={}"
reqs_and_dists = (
(kornia.format("0.2.2"), {"torch<=1.4.0,>=1.0.0"}),
(kornia.format("0.3.1"), {"torch==1.5.0"}),
)
for req, dists in reqs_and_dists:
with subtests.test(req=req):
assert set(ltt.extract_dists([req])) == dists


@@ -1,195 +0,0 @@
import itertools
import pytest
from pip._internal.models.wheel import Wheel
from pip._vendor.packaging.version import Version
import light_the_torch as ltt
import light_the_torch.computation_backend as cb
from light_the_torch._pip.common import InternalLTTError
from light_the_torch._pip.find import maybe_add_option
@pytest.fixture
def patch_extract_dists(mocker):
def patch_extract_dists_(return_value=None):
if return_value is None:
return_value = []
return mocker.patch(
"light_the_torch._pip.find.extract_dists", return_value=return_value
)
return patch_extract_dists_
@pytest.fixture
def patch_run(mocker):
def patch_run_():
return mocker.patch("light_the_torch._pip.find.run")
return patch_run_
CHANNELS = ("stable", "lts", "test", "nightly")
PLATFORMS = ("linux_x86_64", "macosx_10_9_x86_64", "win_amd64")
PLATFORM_MAP = dict(zip(PLATFORMS, ("Linux", "Darwin", "Windows")))
SUPPORTED_PYTHON_VERSIONS = {
Version("11.1"): tuple(f"3.{minor}" for minor in (6, 7, 8, 9)),
Version("11.0"): tuple(f"3.{minor}" for minor in (6, 7, 8, 9)),
Version("10.2"): tuple(f"3.{minor}" for minor in (6, 7, 8, 9)),
Version("10.1"): tuple(f"3.{minor}" for minor in (6, 7, 8, 9)),
Version("10.0"): tuple(f"3.{minor}" for minor in (6, 7, 8)),
Version("9.2"): tuple(f"3.{minor}" for minor in (6, 7, 8, 9)),
Version("9.1"): tuple(f"3.{minor}" for minor in (6,)),
Version("9.0"): tuple(f"3.{minor}" for minor in (6, 7)),
Version("8.0"): tuple(f"3.{minor}" for minor in (6, 7)),
Version("7.5"): tuple(f"3.{minor}" for minor in (6,)),
}
PYTHON_VERSIONS = set(itertools.chain(*SUPPORTED_PYTHON_VERSIONS.values()))
def test_maybe_add_option_already_set():
args = ["--foo", "bar"]
assert maybe_add_option(args, "--foo",) == args
assert maybe_add_option(args, "-f", aliases=("--foo",)) == args
def test_find_links_internal_error(patch_extract_dists, patch_run):
patch_extract_dists()
patch_run()
with pytest.raises(InternalLTTError):
ltt.find_links([])
def test_find_links_computation_backend_detect(
    mocker, patch_extract_dists, patch_run, generic_backend
):
    computation_backends = {generic_backend}
    mocker.patch(
        "light_the_torch.computation_backend.detect_compatible_computation_backends",
        return_value=computation_backends,
    )
    patch_extract_dists()
    run = patch_run()

    with pytest.raises(InternalLTTError):
        ltt.find_links([], computation_backends=None)

    args, _ = run.call_args
    cmd = args[0]
    assert cmd.computation_backends == computation_backends


def test_find_links_unknown_channel():
    with pytest.raises(ValueError):
        ltt.find_links([], channel="channel")
@pytest.mark.parametrize("platform", PLATFORMS)
def test_find_links_platform(patch_extract_dists, patch_run, platform):
    patch_extract_dists()
    run = patch_run()

    with pytest.raises(InternalLTTError):
        ltt.find_links([], platform=platform)

    args, _ = run.call_args
    options = args[2]
    assert options.platform == platform


@pytest.mark.parametrize("python_version", PYTHON_VERSIONS)
def test_find_links_python_version(patch_extract_dists, patch_run, python_version):
    patch_extract_dists()
    run = patch_run()
    python_version_tuple = tuple(int(v) for v in python_version.split("."))

    with pytest.raises(InternalLTTError):
        ltt.find_links([], python_version=python_version)

    args, _ = run.call_args
    options = args[2]
    assert options.python_version == python_version_tuple
def wheel_properties():
    params = []
    for platform in PLATFORMS:
        params.extend(
            [
                (platform, cb.CPUBackend(), python_version)
                for python_version in PYTHON_VERSIONS
            ]
        )

        system = PLATFORM_MAP[platform]
        cuda_versions = cb._MINIMUM_DRIVER_VERSIONS.get(system, {}).keys()
        if not cuda_versions:
            continue

        params.extend(
            [
                (
                    platform,
                    cb.CUDABackend(cuda_version.major, cuda_version.minor),
                    python_version,
                )
                for cuda_version in cuda_versions
                for python_version in SUPPORTED_PYTHON_VERSIONS[cuda_version]
                # deselect CUDA / Python combinations for win_amd64
                if not (
                    platform == "win_amd64"
                    and (
                        (cuda_version == Version("7.5") and python_version == "3.6")
                        or (cuda_version == Version("9.2") and python_version == "3.9")
                        or (cuda_version == Version("10.0") and python_version == "3.8")
                    )
                )
            ]
        )

    return pytest.mark.parametrize(
        ("platform", "computation_backend", "python_version"), params, ids=str
    )
@pytest.mark.slow
@pytest.mark.parametrize(
    "pytorch_dist", ["torch", "torchaudio", "torchtext", "torchvision"]
)
@wheel_properties()
def test_find_links_stable_smoke(
    pytorch_dist, platform, computation_backend, python_version
):
    assert ltt.find_links(
        [pytorch_dist],
        computation_backends=computation_backend,
        platform=platform,
        python_version=python_version,
    )


@pytest.mark.slow
@pytest.mark.parametrize("channel", CHANNELS)
def test_find_links_channel_smoke(channel):
    assert ltt.find_links(
        ["torch"], computation_backends={cb.CPUBackend()}, channel=channel
    )


@pytest.mark.parametrize("python_version", PYTHON_VERSIONS)
def test_mac_torch_ge_1_0_0(patch_extract_dists, patch_run, python_version):
    # See https://github.com/pmeier/light-the-torch/issues/34
    dists = ["torch"]
    patch_extract_dists(return_value=dists)

    links = ltt.find_links(
        dists, python_version=python_version, platform="macosx_10_9_x86_64"
    )

    version = Version(Wheel(links[0]).version)
    assert version >= Version("1.0.0")

@@ -1,17 +0,0 @@
import pytest

import light_the_torch.computation_backend as cb


class GenericComputationBackend(cb.ComputationBackend):
    @property
    def local_specifier(self):
        return "generic"

    def __lt__(self, other):
        return NotImplemented


@pytest.fixture
def generic_backend():
    return GenericComputationBackend()

@@ -1,34 +0,0 @@
import importlib
import pkgutil

import light_the_torch as package_under_test


def test_importability(subtests):
    def is_private(name):
        return name.rsplit(".", 1)[-1].startswith("_")

    def onerror(name):
        if is_private(name):
            return

        with subtests.test(name=name):
            raise

    for finder, name, is_package in pkgutil.walk_packages(
        path=package_under_test.__path__,
        prefix=f"{package_under_test.__name__}.",
        onerror=onerror,
    ):
        if is_private(name):
            continue

        if not is_package:
            try:
                importlib.import_module(name)
            except Exception:
                onerror(name)


def test_version_availability():
    assert isinstance(package_under_test.__version__, str)

tox.ini
@@ -1,60 +0,0 @@
[tox]
;See link below for available options
;https://tox.readthedocs.io/en/latest/config.html
isolated_build = True
envlist = py{36,37,38}
skip_missing_interpreters = True

[testenv]
deps =
    pytest
    pytest-subtests
    pytest-mock
    pytest-cov
commands =
    pytest \
        -c pytest.ini \
        --cov=light_the_torch \
        --cov-report=xml \
        --cov-config=.coveragerc \
        {posargs}

[testenv:format]
requires =
    pre-commit
whitelist_externals =
    pre-commit
skip_install = True
deps =
commands =
    pre-commit run --all-files

[testenv:lint]
whitelist_externals =
    pre-commit
requires =
    pre-commit
deps =
    flake8 >= 3.8
    mypy
commands =
    pre-commit run --all-files
    flake8 \
        --config=.flake8
    mypy \
        --config-file=mypy.ini

[testenv:publishable]
whitelist_externals =
    rm
skip_install = True
deps =
    check-wheel-contents
    pep517
    twine
commands =
    rm -rf build dist light_the_torch.egg-info
    python -m pep517.build --source --binary .
    twine check dist/*
    check-wheel-contents dist