mirror of
				https://github.com/huggingface/text-generation-inference.git
				synced 2023-08-15 01:09:35 +03:00 
			
		
		
		
	feat(server): add flash attention llama (#144)
This commit is contained in:
		
							
								
								
									
										10
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								README.md
									
									
									
									
									
								
							| @@ -51,16 +51,14 @@ to power LLMs api-inference widgets. | ||||
| - Log probabilities | ||||
| - Production ready (distributed tracing with Open Telemetry, Prometheus metrics) | ||||
|  | ||||
| ## Officially supported architectures | ||||
| ## Optimized architectures | ||||
|  | ||||
| - [BLOOM](https://huggingface.co/bigscience/bloom) | ||||
| - [BLOOMZ](https://huggingface.co/bigscience/bloomz) | ||||
| - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl) | ||||
| - [Galactica](https://huggingface.co/facebook/galactica-120b) | ||||
| - [SantaCoder](https://huggingface.co/bigcode/santacoder) | ||||
| - [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b) | ||||
| - [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl) | ||||
| - [FLAN-UL2](https://huggingface.co/google/flan-ul2) | ||||
| - [GPT-Neox](https://huggingface.co/EleutherAI/gpt-neox-20b) | ||||
| - [FLAN-T5](https://huggingface.co/google/flan-t5-xxl) | ||||
| - [Llama](https://github.com/facebookresearch/llama) | ||||
|  | ||||
| Other architectures are supported on a best effort basis using: | ||||
|  | ||||
|   | ||||
| @@ -14,7 +14,7 @@ | ||||
|     "tokens": [ | ||||
|       { | ||||
|         "id": 259, | ||||
|         "text": " ", | ||||
|         "text": "", | ||||
|         "logprob": -1.3656927, | ||||
|         "special": false | ||||
|       }, | ||||
|   | ||||
							
								
								
									
										104
									
								
								server/poetry.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										104
									
								
								server/poetry.lock
									
									
									
										generated
									
									
									
								
							| @@ -517,6 +517,14 @@ tensorflow = ["tensorflow"] | ||||
| testing = ["h5py", "huggingface-hub", "numpy", "pytest", "pytest-benchmark", "setuptools-rust"] | ||||
| torch = ["torch"] | ||||
|  | ||||
| [[package]] | ||||
| name = "sentencepiece" | ||||
| version = "0.1.97" | ||||
| description = "SentencePiece python wrapper" | ||||
| category = "main" | ||||
| optional = false | ||||
| python-versions = "*" | ||||
|  | ||||
| [[package]] | ||||
| name = "setuptools" | ||||
| version = "67.4.0" | ||||
| @@ -530,6 +538,19 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-g | ||||
| testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8 (<5)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] | ||||
| testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] | ||||
|  | ||||
| [[package]] | ||||
| name = "tokenizers" | ||||
| version = "0.13.3" | ||||
| description = "Fast and Customizable Tokenizers" | ||||
| category = "main" | ||||
| optional = false | ||||
| python-versions = "*" | ||||
|  | ||||
| [package.extras] | ||||
| dev = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] | ||||
| docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] | ||||
| testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] | ||||
|  | ||||
| [[package]] | ||||
| name = "tomli" | ||||
| version = "2.0.1" | ||||
| @@ -630,7 +651,7 @@ bnb = ["bitsandbytes"] | ||||
| [metadata] | ||||
| lock-version = "1.1" | ||||
| python-versions = "^3.9" | ||||
| content-hash = "521dc9f3c283dc56f7d2e2f96759919ff27ab49ffd3ae7cd26317b209e7fa98d" | ||||
| content-hash = "1c57379c7b9349d2a860b50b3ab125737a0f6f94f4303d7cb55248cb86ff8b8e" | ||||
|  | ||||
| [metadata.files] | ||||
| accelerate = [ | ||||
| @@ -1116,10 +1137,91 @@ safetensors = [ | ||||
|     {file = "safetensors-0.2.8-cp39-cp39-win_amd64.whl", hash = "sha256:ba3dc236a2344b7feadc9868307f42ba5e4804c9d68a80a35aac831349b31f6f"}, | ||||
|     {file = "safetensors-0.2.8.tar.gz", hash = "sha256:2720b20a6a38c799dca79bd76caeeac2f7df585a9d4f7d59fa7e28eff9ccb27f"}, | ||||
| ] | ||||
| sentencepiece = [ | ||||
|     {file = "sentencepiece-0.1.97-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:6f249c8f1852893be86eae66b19d522c5fb30bbad4fe2d1b07f06fdc86e1907e"}, | ||||
|     {file = "sentencepiece-0.1.97-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:09e1bc53178de70c557a9ba4fece07364b4416ce3d36570726b3372b68aea135"}, | ||||
|     {file = "sentencepiece-0.1.97-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:667193c57fb48b238be7e3d7636cfc8da56cb5bac5559d8f0b647334e1175be8"}, | ||||
|     {file = "sentencepiece-0.1.97-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2780531985af79c6163f63d4f200fec8a28b70b6768d2c19f70d01568a4524e8"}, | ||||
|     {file = "sentencepiece-0.1.97-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:205050670c53ef9015e2a98cce3934bfbcf0aafaa14caa0c618dd5667bc217ee"}, | ||||
|     {file = "sentencepiece-0.1.97-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28b183dadef8e8b6b4645c1c20692d7be0a13ecc3ec1a07b3885c8905516675f"}, | ||||
|     {file = "sentencepiece-0.1.97-cp310-cp310-win32.whl", hash = "sha256:ee3c9dbd558d8d85bb1617087b86df6ea2b856a528669630ce6cedeb4353b823"}, | ||||
|     {file = "sentencepiece-0.1.97-cp310-cp310-win_amd64.whl", hash = "sha256:f7dc55379e2f7dee86537180283db2e5f8418c6825fdd2fe436c724eb5604c05"}, | ||||
|     {file = "sentencepiece-0.1.97-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:ba1b4154f9144c5a7528b00aff5cffaa1a896a1c6ca53ca78b6e74cd2dae5244"}, | ||||
|     {file = "sentencepiece-0.1.97-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac3d90aee5581e55d029d124ac11b6ae2fbae0817863b664b2f2302e966ababb"}, | ||||
|     {file = "sentencepiece-0.1.97-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1c27400f1ac46518a01c87cb7703650e4e48728649feb115d2e3f1102a946a42"}, | ||||
|     {file = "sentencepiece-0.1.97-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6e12a166eba75994ca749aadc4a5056b91b31405f805d6de6e8914cc9741c60"}, | ||||
|     {file = "sentencepiece-0.1.97-cp36-cp36m-win32.whl", hash = "sha256:ed85dff5c0a9b3dd1a414c7e1119f2a19e863fc3f81da525bf7f885ebc883de0"}, | ||||
|     {file = "sentencepiece-0.1.97-cp36-cp36m-win_amd64.whl", hash = "sha256:91a19ab6f40ffbae6d6127119953d2c6a85e93d734953dbc8629fde0d21ace66"}, | ||||
|     {file = "sentencepiece-0.1.97-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:bae580e4a35a9314ff49561ac7c06574fe6afc71b821ed6bb00534e571458156"}, | ||||
|     {file = "sentencepiece-0.1.97-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ad7262e7530c683b186672b5dd0082f82719a50a500a8cfbc4bbd7cde5bff8c"}, | ||||
|     {file = "sentencepiece-0.1.97-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:620cee35279720016735a7c7103cddbd9b84fe5e2f098bd5e673834d69fee2b8"}, | ||||
|     {file = "sentencepiece-0.1.97-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93b921b59914c0ec6697e8c6d5e6b44d99d1298fb1a0af56980a79ade0540c19"}, | ||||
|     {file = "sentencepiece-0.1.97-cp37-cp37m-win32.whl", hash = "sha256:9b9a4c44a31d5f47616e9568dcf31e029b0bfa776e0a252c0b59247881598b09"}, | ||||
|     {file = "sentencepiece-0.1.97-cp37-cp37m-win_amd64.whl", hash = "sha256:f31533cdacced56219e239d3459a003ece35116920dd64b2309d4ad047b77644"}, | ||||
|     {file = "sentencepiece-0.1.97-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:7d643c01d1cad13b9206a276bbe5bc1a468e3d7cf6a26bde7783f945277f859d"}, | ||||
|     {file = "sentencepiece-0.1.97-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:542f1985b1ee279a92bef7740ec0781452372028ce01e15aa88df3228b197ba3"}, | ||||
|     {file = "sentencepiece-0.1.97-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:93701da21fea906dd244bf88cdbe640385a89c45d3c1812b76dbadf8782cdbcd"}, | ||||
|     {file = "sentencepiece-0.1.97-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a51514047b964047b7fadb480d88a5e0f72c02f6ca1ba96258fbbc6e79274a94"}, | ||||
|     {file = "sentencepiece-0.1.97-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e3ae2e9b7a5b6f2aa64ec9240b0c185dabe597d0e787dc4344acfbaef1ffe0b2"}, | ||||
|     {file = "sentencepiece-0.1.97-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:923ee4af16dbae1f2ab358ed09f8a0eb89e40a8198a8b343bf54181482342721"}, | ||||
|     {file = "sentencepiece-0.1.97-cp38-cp38-win32.whl", hash = "sha256:fa6f2b88850b5fae3a05053658824cf9f147c8e3c3b40eb64539a976c83d8a24"}, | ||||
|     {file = "sentencepiece-0.1.97-cp38-cp38-win_amd64.whl", hash = "sha256:5137ff0d0b1cc574751d178650ef800ff8d90bf21eb9f71e9567d4a0548940a5"}, | ||||
|     {file = "sentencepiece-0.1.97-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:f92876271a10494671431ad955bff2d6f8ea59baaf957f5ae5946aff56dfcb90"}, | ||||
|     {file = "sentencepiece-0.1.97-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:35c227b6d55e473033db7e0ecc51b1e99e6ed7607cc08602fb5768132543c81d"}, | ||||
|     {file = "sentencepiece-0.1.97-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1706a8a8188f7b3d4b7922db9bb00c64c4e16ee68ab4caaae79f55b3e18748c7"}, | ||||
|     {file = "sentencepiece-0.1.97-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce61efc1862ccb18856c4aabbd930e13d5bfbb4b09b4f111081ac53a9dc62275"}, | ||||
|     {file = "sentencepiece-0.1.97-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a78c03800ef9f02d320e0159f5768b15357f3e9ebea545c9c4ba7928ba8ba254"}, | ||||
|     {file = "sentencepiece-0.1.97-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753b8088fd685ee787d9f54c84275ab347de558c7c4ebc6accb4c35bf7776f20"}, | ||||
|     {file = "sentencepiece-0.1.97-cp39-cp39-win32.whl", hash = "sha256:24306fd86031c17a1a6ae92671e76a350390a3140a65620bc2843dad7db24e2a"}, | ||||
|     {file = "sentencepiece-0.1.97-cp39-cp39-win_amd64.whl", hash = "sha256:c6641d0b7acec61fde5881ea6ebe098c169557ac9aa3bdabdf124eab5a5592bb"}, | ||||
|     {file = "sentencepiece-0.1.97.tar.gz", hash = "sha256:c901305e0a710bbcd296f66d79e96f744e6e175b29812bd5178318437d4e1f6c"}, | ||||
| ] | ||||
| setuptools = [ | ||||
|     {file = "setuptools-67.4.0-py3-none-any.whl", hash = "sha256:f106dee1b506dee5102cc3f3e9e68137bbad6d47b616be7991714b0c62204251"}, | ||||
|     {file = "setuptools-67.4.0.tar.gz", hash = "sha256:e5fd0a713141a4a105412233c63dc4e17ba0090c8e8334594ac790ec97792330"}, | ||||
| ] | ||||
| tokenizers = [ | ||||
|     {file = "tokenizers-0.13.3-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:f3835c5be51de8c0a092058a4d4380cb9244fb34681fd0a295fbf0a52a5fdf33"}, | ||||
|     {file = "tokenizers-0.13.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4ef4c3e821730f2692489e926b184321e887f34fb8a6b80b8096b966ba663d07"}, | ||||
|     {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5fd1a6a25353e9aa762e2aae5a1e63883cad9f4e997c447ec39d071020459bc"}, | ||||
|     {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee0b1b311d65beab83d7a41c56a1e46ab732a9eed4460648e8eb0bd69fc2d059"}, | ||||
|     {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ef4215284df1277dadbcc5e17d4882bda19f770d02348e73523f7e7d8b8d396"}, | ||||
|     {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4d53976079cff8a033f778fb9adca2d9d69d009c02fa2d71a878b5f3963ed30"}, | ||||
|     {file = "tokenizers-0.13.3-cp310-cp310-win32.whl", hash = "sha256:1f0e3b4c2ea2cd13238ce43548959c118069db7579e5d40ec270ad77da5833ce"}, | ||||
|     {file = "tokenizers-0.13.3-cp310-cp310-win_amd64.whl", hash = "sha256:89649c00d0d7211e8186f7a75dfa1db6996f65edce4b84821817eadcc2d3c79e"}, | ||||
|     {file = "tokenizers-0.13.3-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:56b726e0d2bbc9243872b0144515ba684af5b8d8cd112fb83ee1365e26ec74c8"}, | ||||
|     {file = "tokenizers-0.13.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:cc5c022ce692e1f499d745af293ab9ee6f5d92538ed2faf73f9708c89ee59ce6"}, | ||||
|     {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f55c981ac44ba87c93e847c333e58c12abcbb377a0c2f2ef96e1a266e4184ff2"}, | ||||
|     {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f247eae99800ef821a91f47c5280e9e9afaeed9980fc444208d5aa6ba69ff148"}, | ||||
|     {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4b3e3215d048e94f40f1c95802e45dcc37c5b05eb46280fc2ccc8cd351bff839"}, | ||||
|     {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ba2b0bf01777c9b9bc94b53764d6684554ce98551fec496f71bc5be3a03e98b"}, | ||||
|     {file = "tokenizers-0.13.3-cp311-cp311-win32.whl", hash = "sha256:cc78d77f597d1c458bf0ea7c2a64b6aa06941c7a99cb135b5969b0278824d808"}, | ||||
|     {file = "tokenizers-0.13.3-cp311-cp311-win_amd64.whl", hash = "sha256:ecf182bf59bd541a8876deccf0360f5ae60496fd50b58510048020751cf1724c"}, | ||||
|     {file = "tokenizers-0.13.3-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:0527dc5436a1f6bf2c0327da3145687d3bcfbeab91fed8458920093de3901b44"}, | ||||
|     {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07cbb2c307627dc99b44b22ef05ff4473aa7c7cc1fec8f0a8b37d8a64b1a16d2"}, | ||||
|     {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4560dbdeaae5b7ee0d4e493027e3de6d53c991b5002d7ff95083c99e11dd5ac0"}, | ||||
|     {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64064bd0322405c9374305ab9b4c07152a1474370327499911937fd4a76d004b"}, | ||||
|     {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8c6e2ab0f2e3d939ca66aa1d596602105fe33b505cd2854a4c1717f704c51de"}, | ||||
|     {file = "tokenizers-0.13.3-cp37-cp37m-win32.whl", hash = "sha256:6cc29d410768f960db8677221e497226e545eaaea01aa3613fa0fdf2cc96cff4"}, | ||||
|     {file = "tokenizers-0.13.3-cp37-cp37m-win_amd64.whl", hash = "sha256:fc2a7fdf864554a0dacf09d32e17c0caa9afe72baf9dd7ddedc61973bae352d8"}, | ||||
|     {file = "tokenizers-0.13.3-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:8791dedba834c1fc55e5f1521be325ea3dafb381964be20684b92fdac95d79b7"}, | ||||
|     {file = "tokenizers-0.13.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:d607a6a13718aeb20507bdf2b96162ead5145bbbfa26788d6b833f98b31b26e1"}, | ||||
|     {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3791338f809cd1bf8e4fee6b540b36822434d0c6c6bc47162448deee3f77d425"}, | ||||
|     {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c2f35f30e39e6aab8716f07790f646bdc6e4a853816cc49a95ef2a9016bf9ce6"}, | ||||
|     {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:310204dfed5aa797128b65d63538a9837cbdd15da2a29a77d67eefa489edda26"}, | ||||
|     {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0f9b92ea052305166559f38498b3b0cae159caea712646648aaa272f7160963"}, | ||||
|     {file = "tokenizers-0.13.3-cp38-cp38-win32.whl", hash = "sha256:9a3fa134896c3c1f0da6e762d15141fbff30d094067c8f1157b9fdca593b5806"}, | ||||
|     {file = "tokenizers-0.13.3-cp38-cp38-win_amd64.whl", hash = "sha256:8e7b0cdeace87fa9e760e6a605e0ae8fc14b7d72e9fc19c578116f7287bb873d"}, | ||||
|     {file = "tokenizers-0.13.3-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:00cee1e0859d55507e693a48fa4aef07060c4bb6bd93d80120e18fea9371c66d"}, | ||||
|     {file = "tokenizers-0.13.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:a23ff602d0797cea1d0506ce69b27523b07e70f6dda982ab8cf82402de839088"}, | ||||
|     {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70ce07445050b537d2696022dafb115307abdffd2a5c106f029490f84501ef97"}, | ||||
|     {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:280ffe95f50eaaf655b3a1dc7ff1d9cf4777029dbbc3e63a74e65a056594abc3"}, | ||||
|     {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:97acfcec592f7e9de8cadcdcda50a7134423ac8455c0166b28c9ff04d227b371"}, | ||||
|     {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd7730c98a3010cd4f523465867ff95cd9d6430db46676ce79358f65ae39797b"}, | ||||
|     {file = "tokenizers-0.13.3-cp39-cp39-win32.whl", hash = "sha256:48625a108029cb1ddf42e17a81b5a3230ba6888a70c9dc14e81bc319e812652d"}, | ||||
|     {file = "tokenizers-0.13.3-cp39-cp39-win_amd64.whl", hash = "sha256:bc0a6f1ba036e482db6453571c9e3e60ecd5489980ffd95d11dc9f960483d783"}, | ||||
|     {file = "tokenizers-0.13.3.tar.gz", hash = "sha256:2e546dbb68b623008a5442353137fbb0123d311a6d7ba52f2667c8862a75af2e"}, | ||||
| ] | ||||
| tomli = [ | ||||
|     {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, | ||||
|     {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, | ||||
|   | ||||
| @@ -23,6 +23,8 @@ opentelemetry-api = "^1.15.0" | ||||
| opentelemetry-exporter-otlp = "^1.15.0" | ||||
| opentelemetry-instrumentation-grpc = "^0.36b0" | ||||
| hf-transfer = "^0.1.2" | ||||
| sentencepiece = "^0.1.97" | ||||
| tokenizers = "0.13.3" | ||||
|  | ||||
| [tool.poetry.extras] | ||||
| bnb = ["bitsandbytes"] | ||||
|   | ||||
| @@ -148,7 +148,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch) | ||||
|     assert all([generation.generated_text is None for generation in generations]) | ||||
|     assert all([len(generation.prefill_tokens) == 1 for generation in generations]) | ||||
|     assert all([generation.token_id.item() == 259 for generation in generations]) | ||||
|     assert all([generation.token_text == " " for generation in generations]) | ||||
|     assert all([generation.token_text == "" for generation in generations]) | ||||
|     assert generations[0].request_id == 0 | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -19,13 +19,11 @@ from text_generation_server.models.t5 import T5Sharded | ||||
| try: | ||||
|     from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded | ||||
|     from text_generation_server.models.flash_santacoder import FlashSantacoder | ||||
|     from text_generation_server.models.flash_llama import FlashLlama, FlashLlamaSharded | ||||
|  | ||||
|     FLASH_ATTENTION = ( | ||||
|         torch.cuda.is_available() and int(os.environ.get("FLASH_ATTENTION", 0)) == 1 | ||||
|     ) | ||||
|     FLASH_ATTENTION = torch.cuda.is_available() | ||||
| except ImportError: | ||||
|     if int(os.environ.get("FLASH_ATTENTION", 0)) == 1: | ||||
|         logger.exception("Could not import Flash Attention models") | ||||
|     logger.exception("Could not import Flash Attention enabled models") | ||||
|     FLASH_ATTENTION = False | ||||
|  | ||||
| __all__ = [ | ||||
| @@ -47,6 +45,12 @@ if FLASH_ATTENTION: | ||||
|     __all__.append(FlashNeoX) | ||||
|     __all__.append(FlashNeoXSharded) | ||||
|     __all__.append(FlashSantacoder) | ||||
|     __all__.append(FlashLlama) | ||||
|     __all__.append(FlashLlamaSharded) | ||||
|  | ||||
| FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention CUDA kernels to be installed.\n" \ | ||||
|                           "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " \ | ||||
|                           "or install flash attention with `cd server && make install install-flash-attention`" | ||||
|  | ||||
| # The flag below controls whether to allow TF32 on matmul. This flag defaults to False | ||||
| # in PyTorch 1.12 and later. | ||||
| @@ -60,7 +64,7 @@ torch.set_grad_enabled(False) | ||||
|  | ||||
|  | ||||
| def get_model( | ||||
|     model_id: str, revision: Optional[str], sharded: bool, quantize: bool | ||||
|         model_id: str, revision: Optional[str], sharded: bool, quantize: bool | ||||
| ) -> Model: | ||||
|     if "facebook/galactica" in model_id: | ||||
|         if sharded: | ||||
| @@ -92,6 +96,17 @@ def get_model( | ||||
|             neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM | ||||
|             return neox_cls(model_id, revision, quantize=quantize) | ||||
|  | ||||
|     if model_type == "llama": | ||||
|         if sharded: | ||||
|             if FLASH_ATTENTION: | ||||
|                 return FlashLlamaSharded(model_id, revision, quantize=quantize) | ||||
|             raise NotImplementedError( | ||||
|                 FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama") | ||||
|             ) | ||||
|         else: | ||||
|             llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM | ||||
|             return llama_cls(model_id, revision, quantize=quantize) | ||||
|  | ||||
|     if model_type == "t5": | ||||
|         if sharded: | ||||
|             return T5Sharded(model_id, revision, quantize=quantize) | ||||
|   | ||||
| @@ -34,6 +34,8 @@ class CausalLMBatch(Batch): | ||||
|  | ||||
|     # Lengths of all generations present in the batch | ||||
|     input_lengths: List[int] | ||||
|     offsets: List[Optional[int]] | ||||
|     token_offsets: List[Optional[int]] | ||||
|  | ||||
|     # Generation helpers | ||||
|     next_token_choosers: List[NextTokenChooser] | ||||
| @@ -64,12 +66,16 @@ class CausalLMBatch(Batch): | ||||
|         inputs = [] | ||||
|         next_token_choosers = [] | ||||
|         stopping_criterias = [] | ||||
|         offsets = [] | ||||
|         token_offsets = [] | ||||
|  | ||||
|         # Parse batch | ||||
|         max_truncation = 0 | ||||
|         padding_right_offset = 0 | ||||
|         for r in pb.requests: | ||||
|             inputs.append(r.inputs) | ||||
|             offsets.append(None) | ||||
|             token_offsets.append(None) | ||||
|             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) | ||||
|             stopping_criteria = StoppingCriteria.from_pb( | ||||
|                 r.stopping_parameters, tokenizer | ||||
| @@ -113,6 +119,8 @@ class CausalLMBatch(Batch): | ||||
|             past_key_values=None, | ||||
|             all_input_ids=all_input_ids, | ||||
|             input_lengths=input_lengths.tolist(), | ||||
|             offsets=offsets, | ||||
|             token_offsets=token_offsets, | ||||
|             next_token_choosers=next_token_choosers, | ||||
|             stopping_criterias=stopping_criterias, | ||||
|             size=pb.size, | ||||
| @@ -135,6 +143,8 @@ class CausalLMBatch(Batch): | ||||
|         # Batch attributes | ||||
|         requests = [] | ||||
|         input_lengths = [] | ||||
|         offsets = [] | ||||
|         token_offsets = [] | ||||
|         all_input_ids = [] | ||||
|         next_token_choosers = [] | ||||
|         stopping_criterias = [] | ||||
| @@ -151,6 +161,8 @@ class CausalLMBatch(Batch): | ||||
|         for i, batch in enumerate(batches): | ||||
|             requests.extend(batch.requests) | ||||
|             input_lengths.extend(batch.input_lengths) | ||||
|             offsets.extend(batch.offsets) | ||||
|             token_offsets.extend(batch.token_offsets) | ||||
|             all_input_ids.extend(batch.all_input_ids) | ||||
|             next_token_choosers.extend(batch.next_token_choosers) | ||||
|             stopping_criterias.extend(batch.stopping_criterias) | ||||
| @@ -264,6 +276,8 @@ class CausalLMBatch(Batch): | ||||
|             past_key_values=past_key_values, | ||||
|             all_input_ids=all_input_ids, | ||||
|             input_lengths=input_lengths, | ||||
|             offsets=offsets, | ||||
|             token_offsets=token_offsets, | ||||
|             next_token_choosers=next_token_choosers, | ||||
|             stopping_criterias=stopping_criterias, | ||||
|             size=total_batch_size, | ||||
| @@ -289,7 +303,7 @@ class CausalLM(Model): | ||||
|             dtype = torch.float32 | ||||
|  | ||||
|         tokenizer = AutoTokenizer.from_pretrained( | ||||
|             model_id, revision=revision, padding_side="left" | ||||
|             model_id, revision=revision, padding_side="left", truncation_side="left" | ||||
|         ) | ||||
|         self.model = AutoModelForCausalLM.from_pretrained( | ||||
|             model_id, | ||||
| @@ -350,6 +364,8 @@ class CausalLM(Model): | ||||
|  | ||||
|         # New values for next forward | ||||
|         next_batch_input_lengths = [] | ||||
|         next_batch_offsets = [] | ||||
|         next_batch_token_offsets = [] | ||||
|         next_batch_input_ids = [] | ||||
|         next_batch_all_input_ids = [] | ||||
|  | ||||
| @@ -364,6 +380,8 @@ class CausalLM(Model): | ||||
|         iterator = zip( | ||||
|             batch.requests, | ||||
|             batch.input_lengths, | ||||
|             batch.offsets, | ||||
|             batch.token_offsets, | ||||
|             logits, | ||||
|             batch.next_token_choosers, | ||||
|             batch.stopping_criterias, | ||||
| @@ -374,6 +392,8 @@ class CausalLM(Model): | ||||
|         for i, ( | ||||
|             request, | ||||
|             input_length, | ||||
|             offset, | ||||
|             token_offset, | ||||
|             logits, | ||||
|             next_token_chooser, | ||||
|             stopping_criteria, | ||||
| @@ -391,8 +411,8 @@ class CausalLM(Model): | ||||
|             # Generated token | ||||
|             next_token_logprob = logprobs[-1, next_token_id] | ||||
|             next_token_id_squeezed = next_token_id.squeeze() | ||||
|             next_token_text = self.decode_token( | ||||
|                 next_token_id_squeezed, | ||||
|             next_token_text, offset, token_offset = self.decode_token( | ||||
|                 all_input_ids[:, 0], offset, token_offset | ||||
|             ) | ||||
|  | ||||
|             # Evaluate stopping criteria | ||||
| @@ -423,6 +443,8 @@ class CausalLM(Model): | ||||
|                 next_batch_all_input_ids.append(all_input_ids) | ||||
|                 next_batch_size += 1 | ||||
|                 next_batch_input_lengths.append(new_input_length) | ||||
|                 next_batch_offsets.append(offset) | ||||
|                 next_batch_token_offsets.append(token_offset) | ||||
|                 next_batch_max_input_length = max( | ||||
|                     next_batch_max_input_length, new_input_length | ||||
|                 ) | ||||
| @@ -506,6 +528,8 @@ class CausalLM(Model): | ||||
|             past_key_values=next_batch_past_key_values, | ||||
|             all_input_ids=next_batch_all_input_ids, | ||||
|             input_lengths=next_batch_input_lengths, | ||||
|             offsets=next_batch_offsets, | ||||
|             token_offsets=next_batch_token_offsets, | ||||
|             next_token_choosers=next_batch_next_token_choosers, | ||||
|             stopping_criterias=next_batch_stopping_criterias, | ||||
|             size=next_batch_size, | ||||
|   | ||||
| @@ -0,0 +1,619 @@ | ||||
| # coding=utf-8 | ||||
| # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. | ||||
| # | ||||
| # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX | ||||
| # and OPT implementations in this library. It has been modified from its | ||||
| # original forms to accommodate minor architectural differences compared | ||||
| # to GPT-NeoX and OPT used by the Meta AI team that trained the model. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
|  | ||||
| import torch | ||||
| import torch.distributed | ||||
|  | ||||
| from torch.nn import functional as F | ||||
|  | ||||
| from torch import nn | ||||
| from transformers.activations import ACT2FN | ||||
|  | ||||
| # Flash attention imports | ||||
| import rotary_emb | ||||
| import flash_attn_cuda | ||||
| import dropout_layer_norm | ||||
|  | ||||
| from flash_attn.layers.rotary import RotaryEmbedding | ||||
|  | ||||
|  | ||||
| class LlamaRMSNorm(nn.Module): | ||||
|     def __init__(self, hidden_size, eps=1e-6): | ||||
|         """ | ||||
|         LlamaRMSNorm is equivalent to T5LayerNorm | ||||
|         """ | ||||
|         super().__init__() | ||||
|         self.weight = nn.Parameter(torch.ones(hidden_size)) | ||||
|         self.variance_epsilon = eps | ||||
|  | ||||
|     def forward(self, hidden_states, residual=None): | ||||
|         if hidden_states.shape[-1] > 8192: | ||||
|             if residual is not None: | ||||
|                 hidden_states += residual | ||||
|             residual = hidden_states | ||||
|  | ||||
|             variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) | ||||
|             hidden_states = hidden_states * torch.rsqrt( | ||||
|                 variance + self.variance_epsilon | ||||
|             ) | ||||
|  | ||||
|             # convert into half-precision if necessary | ||||
|             if self.weight.dtype in [torch.float16, torch.bfloat16]: | ||||
|                 hidden_states = hidden_states.to(self.weight.dtype) | ||||
|  | ||||
|             return self.weight * hidden_states, residual | ||||
|         else: | ||||
|             # faster post attention rms norm | ||||
|             normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( | ||||
|                 hidden_states, | ||||
|                 residual, | ||||
|                 self.weight, | ||||
|                 None, | ||||
|                 None, | ||||
|                 None, | ||||
|                 None, | ||||
|                 None, | ||||
|                 0.0, | ||||
|                 self.variance_epsilon, | ||||
|                 1.0, | ||||
|                 0, | ||||
|                 None, | ||||
|                 False, | ||||
|                 True,  # Activate RMSNorm | ||||
|             ) | ||||
|             if res is None: | ||||
|                 res = hidden_states | ||||
|  | ||||
|             return normed_hidden_states, res | ||||
|  | ||||
|  | ||||
| class FastLinear(nn.Linear): | ||||
|     def __init__( | ||||
|         self, | ||||
|         in_features: int, | ||||
|         out_features: int, | ||||
|         bias: bool = True, | ||||
|         device=None, | ||||
|         dtype=None, | ||||
|     ) -> None: | ||||
|         super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) | ||||
|  | ||||
|     def transpose_weight(self): | ||||
|         self.weight = nn.Parameter(self.weight.T) | ||||
|  | ||||
|     def forward(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         if self.bias is not None: | ||||
|             return torch.addmm(self.bias, input, self.weight) | ||||
|         return torch.matmul(input, self.weight) | ||||
|  | ||||
|  | ||||
| class TensorParallelColumnLinear(FastLinear): | ||||
|     def __init__( | ||||
|         self, | ||||
|         in_features, | ||||
|         out_features, | ||||
|         process_group: torch.distributed.ProcessGroup, | ||||
|         bias=True, | ||||
|         device=None, | ||||
|         dtype=None, | ||||
|     ): | ||||
|         self.process_group = process_group | ||||
|         self.tp_world_size = process_group.size() | ||||
|         assert out_features % self.tp_world_size == 0 | ||||
|         out_features = out_features // self.tp_world_size | ||||
|  | ||||
|         super().__init__( | ||||
|             in_features=in_features, | ||||
|             out_features=out_features, | ||||
|             bias=bias, | ||||
|             device=device, | ||||
|             dtype=dtype, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class TensorParallelRowLinear(FastLinear): | ||||
|     def __init__( | ||||
|         self, | ||||
|         in_features, | ||||
|         out_features, | ||||
|         process_group: torch.distributed.ProcessGroup, | ||||
|         reduce=True, | ||||
|         bias=True, | ||||
|         device=None, | ||||
|         dtype=None, | ||||
|     ): | ||||
|         self.process_group = process_group | ||||
|         self.tp_world_size = process_group.size() | ||||
|         self.reduce = reduce | ||||
|         assert in_features % self.tp_world_size == 0 | ||||
|         in_features = in_features // self.tp_world_size | ||||
|  | ||||
|         super().__init__( | ||||
|             in_features=in_features, | ||||
|             out_features=out_features, | ||||
|             bias=bias, | ||||
|             device=device, | ||||
|             dtype=dtype, | ||||
|         ) | ||||
|  | ||||
|     def forward(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         out = super(TensorParallelRowLinear, self).forward(input) | ||||
|         if self.reduce: | ||||
|             torch.distributed.all_reduce(out, group=self.process_group) | ||||
|  | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class TensorParallelEmbedding(nn.Embedding): | ||||
|     def __init__( | ||||
|         self, | ||||
|         num_embeddings, | ||||
|         embedding_dim, | ||||
|         process_group: torch.distributed.ProcessGroup, | ||||
|         padding_idx=None, | ||||
|         max_norm=None, | ||||
|         norm_type=2.0, | ||||
|         scale_grad_by_freq=False, | ||||
|         sparse=False, | ||||
|         _weight=None, | ||||
|         device=None, | ||||
|         dtype=None, | ||||
|     ): | ||||
|         self.process_group = process_group | ||||
|         self.tp_rank = process_group.rank() | ||||
|         self.tp_world_size = process_group.size() | ||||
|  | ||||
|         self.original_num_embeddings = num_embeddings | ||||
|  | ||||
|         assert num_embeddings % self.tp_world_size == 0 | ||||
|         block_size = num_embeddings // self.tp_world_size | ||||
|         # inputs in `[min_id, max_id[` are handled by `self` to get embeddings | ||||
|         self.min_id = self.tp_rank * block_size | ||||
|         self.max_id = (self.tp_rank + 1) * block_size | ||||
|  | ||||
|         # Additional entry that will map to zero | ||||
|         # Used for masking | ||||
|         self.null_idx = block_size | ||||
|  | ||||
|         super().__init__( | ||||
|             block_size, | ||||
|             embedding_dim, | ||||
|             padding_idx=padding_idx, | ||||
|             max_norm=max_norm, | ||||
|             norm_type=norm_type, | ||||
|             scale_grad_by_freq=scale_grad_by_freq, | ||||
|             sparse=sparse, | ||||
|             _weight=_weight, | ||||
|             device=device, | ||||
|             dtype=dtype, | ||||
|         ) | ||||
|  | ||||
|     def add_null_idx(self): | ||||
|         """Additional 0 entry used for masking""" | ||||
|         self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1))) | ||||
|  | ||||
|     def forward(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         # default all out of bounds values to `self.null_idx` that will then be mapped to 0 | ||||
|         # translate for [0, self.max_id - self.min_id[ | ||||
|         input = torch.where( | ||||
|             (self.min_id > input) | (input >= self.max_id), | ||||
|             self.null_idx, | ||||
|             input - self.min_id, | ||||
|         ) | ||||
|         out = super().forward(input) | ||||
|         torch.distributed.all_reduce(out, group=self.process_group) | ||||
|         return out | ||||
|  | ||||
|  | ||||
| class PositionRotaryEmbedding(RotaryEmbedding): | ||||
|     def _update_cos_sin_cache(self, dtype, device, seqlen): | ||||
|         # Reset the tables if the sequence length has changed, | ||||
|         # or if we're on a new device (possibly due to tracing for instance) | ||||
|         if ( | ||||
|             seqlen > self._seq_len_cached | ||||
|             or self._cos_cached.device != device | ||||
|             or self._cos_cached.dtype != dtype | ||||
|         ): | ||||
|             self._seq_len_cached = seqlen | ||||
|             t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) | ||||
|             freqs = torch.outer(t, self.inv_freq.to(device=t.device)) | ||||
|             self._cos_cached = torch.cos(freqs).to(dtype) | ||||
|             self._sin_cached = torch.sin(freqs).to(dtype) | ||||
|  | ||||
|     def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype): | ||||
|         """ | ||||
|         Return cos and sin for the asked position ids | ||||
|         """ | ||||
|  | ||||
|         self._update_cos_sin_cache(dtype, position_ids.device, max_s) | ||||
|  | ||||
|         cos = torch.index_select(self._cos_cached, 0, position_ids) | ||||
|         sin = torch.index_select(self._sin_cached, 0, position_ids) | ||||
|         return cos.unsqueeze(1), sin.unsqueeze(1) | ||||
|  | ||||
|     def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): | ||||
|         rotary_dim = cos.shape[-1] | ||||
|         q1 = qkv[:, 0, :, :rotary_dim] | ||||
|         q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim] | ||||
|         k1 = qkv[:, 1, :, :rotary_dim] | ||||
|         k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim] | ||||
|  | ||||
|         rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) | ||||
|         rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) | ||||
|         return qkv | ||||
|  | ||||
|  | ||||
| class FlashLlamaAttention(torch.nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         num_heads, | ||||
|         hidden_size, | ||||
|         process_group=None, | ||||
|     ): | ||||
|         super().__init__() | ||||
|         self.num_heads = num_heads | ||||
|         self.hidden_size = hidden_size | ||||
|         self.head_size = hidden_size // num_heads | ||||
|  | ||||
|         self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000) | ||||
|         self.softmax_scale = self.head_size ** (-0.5) | ||||
|  | ||||
|         if process_group is None: | ||||
|             self.query_key_value = FastLinear(hidden_size, 3 * hidden_size, bias=False) | ||||
|             self.o_proj = FastLinear(hidden_size, hidden_size, bias=False) | ||||
|         else: | ||||
|             self.num_heads = self.num_heads // process_group.size() | ||||
|             self.query_key_value = TensorParallelColumnLinear( | ||||
|                 hidden_size, | ||||
|                 3 * hidden_size, | ||||
|                 bias=False, | ||||
|                 process_group=process_group, | ||||
|             ) | ||||
|             self.o_proj = TensorParallelRowLinear( | ||||
|                 hidden_size, | ||||
|                 hidden_size, | ||||
|                 bias=False, | ||||
|                 process_group=process_group, | ||||
|             ) | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states, | ||||
|         cos, | ||||
|         sin, | ||||
|         cu_seqlens, | ||||
|         max_s, | ||||
|         layer_past, | ||||
|         layer_past_present_indices, | ||||
|         cu_seqlens_q, | ||||
|     ): | ||||
|         qkv = self.query_key_value(hidden_states) | ||||
|         qkv = qkv.view(-1, 3, self.num_heads, self.head_size) | ||||
|         qkv_rot = self.rotary_emb(qkv, cos, sin) | ||||
|  | ||||
|         # Prefill | ||||
|         if layer_past_present_indices is None: | ||||
|             # Copy to layer past | ||||
|             layer_past[...] = qkv_rot[:, 1:] | ||||
|  | ||||
|             # output | ||||
|             attn_output = torch.empty_like(qkv_rot[:, 0]) | ||||
|             # flash attention | ||||
|             flash_attn_cuda.fwd( | ||||
|                 qkv_rot[:, 0], | ||||
|                 qkv_rot[:, 1], | ||||
|                 qkv_rot[:, 2], | ||||
|                 attn_output, | ||||
|                 cu_seqlens, | ||||
|                 cu_seqlens, | ||||
|                 max_s, | ||||
|                 max_s, | ||||
|                 0.0, | ||||
|                 self.softmax_scale, | ||||
|                 False, | ||||
|                 True, | ||||
|                 False, | ||||
|                 0, | ||||
|                 None, | ||||
|             ) | ||||
|         # Decode | ||||
|         else: | ||||
|             query = qkv_rot[:, 0] | ||||
|             # Add present to the layer_past tensor at the correct indices | ||||
|             layer_past[layer_past_present_indices] = qkv_rot[:, 1:] | ||||
|  | ||||
|             # output | ||||
|             attn_output = torch.empty_like(query) | ||||
|             # flash attention | ||||
|             flash_attn_cuda.fwd( | ||||
|                 query, | ||||
|                 layer_past[:, 0], | ||||
|                 layer_past[:, 1], | ||||
|                 attn_output, | ||||
|                 cu_seqlens_q, | ||||
|                 cu_seqlens, | ||||
|                 1, | ||||
|                 max_s, | ||||
|                 0.0, | ||||
|                 self.softmax_scale, | ||||
|                 False, | ||||
|                 False, | ||||
|                 False, | ||||
|                 0, | ||||
|                 None, | ||||
|             ) | ||||
|  | ||||
|         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) | ||||
|  | ||||
|  | ||||
| class LlamaMLP(nn.Module): | ||||
|     def __init__(self, act, hidden_size, intermediate_size, process_group=None): | ||||
|         super().__init__() | ||||
|         self.act = ( | ||||
|             ACT2FN[act] | ||||
|             if "gelu" not in act | ||||
|             else lambda x: torch.nn.functional.gelu( | ||||
|                 x, | ||||
|                 approximate="tanh" | ||||
|                 if act in ["gelu_fast", "gelu_pytorch_tanh"] | ||||
|                 else None, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         if process_group is None: | ||||
|             # Fuse gate and up proj | ||||
|             self.gate_up_proj = FastLinear( | ||||
|                 hidden_size, 2 * intermediate_size, bias=False | ||||
|             ) | ||||
|             self.down_proj = FastLinear(intermediate_size, hidden_size, bias=False) | ||||
|             self.intermediate_size = intermediate_size | ||||
|         else: | ||||
|             # Fuse gate and up proj | ||||
|             self.gate_up_proj = TensorParallelColumnLinear( | ||||
|                 hidden_size, | ||||
|                 2 * intermediate_size, | ||||
|                 bias=False, | ||||
|                 process_group=process_group, | ||||
|             ) | ||||
|             self.down_proj = TensorParallelRowLinear( | ||||
|                 intermediate_size, | ||||
|                 hidden_size, | ||||
|                 bias=False, | ||||
|                 process_group=process_group, | ||||
|                 reduce=True, | ||||
|             ) | ||||
|             self.intermediate_size = self.down_proj.in_features | ||||
|  | ||||
|         self.process_group = process_group | ||||
|  | ||||
|     def forward(self, hidden_states): | ||||
|         gate_up_states = self.gate_up_proj(hidden_states) | ||||
|         gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) | ||||
|         return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) | ||||
|  | ||||
|  | ||||
| class FlashLlamaLayer(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         num_heads, | ||||
|         act, | ||||
|         hidden_size, | ||||
|         intermediate_size, | ||||
|         rms_norm_eps, | ||||
|         process_group=None, | ||||
|     ): | ||||
|         super().__init__() | ||||
|  | ||||
|         self.self_attn = FlashLlamaAttention(num_heads, hidden_size, process_group) | ||||
|         self.mlp = LlamaMLP(act, hidden_size, intermediate_size, process_group) | ||||
|  | ||||
|         self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps) | ||||
|         self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps) | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states, | ||||
|         residual, | ||||
|         cos, | ||||
|         sin, | ||||
|         cu_seqlens, | ||||
|         max_s, | ||||
|         layer_past, | ||||
|         layer_past_present_indices, | ||||
|         cu_seqlens_q, | ||||
|     ): | ||||
|         normed_hidden_states, res = self.input_layernorm(hidden_states, residual) | ||||
|  | ||||
|         # Self Attention | ||||
|         attn_output = self.self_attn( | ||||
|             normed_hidden_states, | ||||
|             cos, | ||||
|             sin, | ||||
|             cu_seqlens, | ||||
|             max_s, | ||||
|             layer_past, | ||||
|             layer_past_present_indices, | ||||
|             cu_seqlens_q, | ||||
|         ) | ||||
|  | ||||
|         # faster post attention rms norm | ||||
|         normed_attn_res_output, attn_res = self.post_attention_layernorm( | ||||
|             attn_output, res | ||||
|         ) | ||||
|  | ||||
|         mlp_output = self.mlp(normed_attn_res_output) | ||||
|  | ||||
|         return mlp_output, attn_res | ||||
|  | ||||
|  | ||||
| class FlashLlamaModel(torch.nn.Module): | ||||
|     def __init__(self, config, process_group=None): | ||||
|         super(FlashLlamaModel, self).__init__() | ||||
|         self.config = config | ||||
|  | ||||
|         self.tp_embeddings = False | ||||
|         if process_group is not None: | ||||
|             self.tp_rank = process_group.rank() | ||||
|             self.tp_world_size = process_group.size() | ||||
|             if config.vocab_size % self.tp_world_size == 0: | ||||
|                 self.tp_embeddings = True | ||||
|  | ||||
|         if self.tp_embeddings: | ||||
|             self.embed_tokens = TensorParallelEmbedding( | ||||
|                 config.vocab_size, config.hidden_size, process_group=process_group | ||||
|             ) | ||||
|         else: | ||||
|             self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) | ||||
|  | ||||
|         self.layers = nn.ModuleList( | ||||
|             [ | ||||
|                 FlashLlamaLayer( | ||||
|                     config.num_attention_heads, | ||||
|                     config.hidden_act, | ||||
|                     config.hidden_size, | ||||
|                     config.intermediate_size, | ||||
|                     config.rms_norm_eps, | ||||
|                     process_group, | ||||
|                 ) | ||||
|                 for _ in range(config.num_hidden_layers) | ||||
|             ] | ||||
|         ) | ||||
|         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||||
|  | ||||
|         self.gradient_checkpointing = False | ||||
|  | ||||
|         self.head_size = self.layers[0].self_attn.head_size | ||||
|         self.num_heads = self.layers[0].self_attn.num_heads | ||||
|  | ||||
|     def post_load_weights(self): | ||||
|         if isinstance(self.embed_tokens, TensorParallelEmbedding): | ||||
|             self.embed_tokens.add_null_idx() | ||||
|         for layer in self.layers: | ||||
|             layer: FlashLlamaLayer | ||||
|             layer.self_attn.query_key_value.transpose_weight() | ||||
|             layer.self_attn.o_proj.transpose_weight() | ||||
|             layer.mlp.gate_up_proj.transpose_weight() | ||||
|             layer.mlp.down_proj.transpose_weight() | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         input_ids, | ||||
|         position_ids, | ||||
|         cu_seqlens, | ||||
|         max_s, | ||||
|         past_key_values=None, | ||||
|     ): | ||||
|         hidden_states = self.embed_tokens(input_ids) | ||||
|  | ||||
|         # Prefill | ||||
|         if past_key_values is None: | ||||
|             # Create past tensor | ||||
|             past_key_values = hidden_states.new_empty( | ||||
|                 ( | ||||
|                     len(self.layers), | ||||
|                     len(hidden_states), | ||||
|                     2, | ||||
|                     self.num_heads, | ||||
|                     self.head_size, | ||||
|                 ) | ||||
|             ) | ||||
|             layer_past_present_indices = None | ||||
|             cu_seqlens_q = None | ||||
|         # Decode | ||||
|         else: | ||||
|             # Create indices from cumulative sequence lengths | ||||
|             layer_past_present_indices = cu_seqlens[1:] - 1 | ||||
|             cu_seqlens_q = torch.arange( | ||||
|                 cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device | ||||
|             ) | ||||
|  | ||||
|         # Get rotary cos and sin for this forward | ||||
|         # Avoid to index in each layer | ||||
|         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( | ||||
|             position_ids, max_s, hidden_states.dtype | ||||
|         ) | ||||
|  | ||||
|         residual = None | ||||
|         for i, layer in enumerate(self.layers): | ||||
|             hidden_states, residual = layer( | ||||
|                 hidden_states, | ||||
|                 residual, | ||||
|                 cos, | ||||
|                 sin, | ||||
|                 cu_seqlens, | ||||
|                 max_s, | ||||
|                 past_key_values[i], | ||||
|                 layer_past_present_indices, | ||||
|                 cu_seqlens_q, | ||||
|             ) | ||||
|  | ||||
|         hidden_states, _ = self.norm(hidden_states, residual) | ||||
|  | ||||
|         return hidden_states, past_key_values | ||||
|  | ||||
|  | ||||
| class FlashLlamaForCausalLM(torch.nn.Module): | ||||
|     def __init__(self, config, process_group=None): | ||||
|         super().__init__() | ||||
|  | ||||
|         self.process_group = process_group | ||||
|         if self.process_group is not None: | ||||
|             self.world_size = self.process_group.size() | ||||
|             self.rank = self.process_group.rank() | ||||
|         else: | ||||
|             self.world_size = 1 | ||||
|             self.rank = 0 | ||||
|  | ||||
|         self.model = FlashLlamaModel(config, process_group) | ||||
|  | ||||
|         if self.model.tp_embeddings: | ||||
|             self.lm_head = FastLinear( | ||||
|                 config.hidden_size, | ||||
|                 config.vocab_size // process_group.size(), | ||||
|                 bias=False, | ||||
|             ) | ||||
|         else: | ||||
|             self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False) | ||||
|  | ||||
|     def post_load_weights(self): | ||||
|         self.model.post_load_weights() | ||||
|         self.lm_head.transpose_weight() | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         input_ids, | ||||
|         position_ids, | ||||
|         cu_seqlens, | ||||
|         max_s, | ||||
|         past_key_values=None, | ||||
|     ): | ||||
|         hidden_states, present = self.model( | ||||
|             input_ids, position_ids, cu_seqlens, max_s, past_key_values | ||||
|         ) | ||||
|         logits = self.lm_head(hidden_states) | ||||
|  | ||||
|         if self.model.tp_embeddings: | ||||
|             # Logits are sharded, so we need to gather them | ||||
|             world_logits = [torch.empty_like(logits) for _ in range(self.world_size)] | ||||
|             torch.distributed.all_gather(world_logits, logits, group=self.process_group) | ||||
|             world_logits = torch.cat(world_logits, dim=1) | ||||
|  | ||||
|             return world_logits, present | ||||
|         return logits, present | ||||
| @@ -44,6 +44,8 @@ class FlashCausalLMBatch(Batch): | ||||
|  | ||||
|     # Lengths of all generations present in the batch | ||||
|     input_lengths: List[int] | ||||
|     offsets: List[Optional[int]] | ||||
|     token_offsets: List[Optional[int]] | ||||
|  | ||||
|     # Generation helpers | ||||
|     next_token_choosers: List[NextTokenChooser] | ||||
| @@ -67,6 +69,8 @@ class FlashCausalLMBatch(Batch): | ||||
|         max_seqlen = 0 | ||||
|  | ||||
|         input_lengths = [] | ||||
|         offsets = [] | ||||
|         token_offsets = [] | ||||
|         all_input_ids = [] | ||||
|         all_input_ids_tensor = [] | ||||
|  | ||||
| @@ -84,6 +88,8 @@ class FlashCausalLMBatch(Batch): | ||||
|             input_length = len(tokenized_input) | ||||
|             max_seqlen = max(max_seqlen, input_length) | ||||
|             input_lengths.append(input_length) | ||||
|             offsets.append(None) | ||||
|             token_offsets.append(None) | ||||
|             all_input_ids.append(tokenized_input) | ||||
|  | ||||
|             tokenized_input = torch.tensor(tokenized_input, device=device) | ||||
| @@ -120,6 +126,8 @@ class FlashCausalLMBatch(Batch): | ||||
|             max_seqlen=max_seqlen, | ||||
|             past_key_values=None, | ||||
|             input_lengths=input_lengths, | ||||
|             offsets=offsets, | ||||
|             token_offsets=token_offsets, | ||||
|             all_input_ids=all_input_ids, | ||||
|             all_input_ids_tensor=all_input_ids_tensor, | ||||
|             next_token_choosers=next_token_choosers, | ||||
| @@ -132,6 +140,8 @@ class FlashCausalLMBatch(Batch): | ||||
|         # Batch attributes | ||||
|         requests = [] | ||||
|         input_lengths = [] | ||||
|         offsets = [] | ||||
|         token_offsets = [] | ||||
|         all_input_ids = [] | ||||
|         all_input_ids_tensor = [] | ||||
|         next_token_choosers = [] | ||||
| @@ -150,6 +160,8 @@ class FlashCausalLMBatch(Batch): | ||||
|         for i, batch in enumerate(batches): | ||||
|             requests.extend(batch.requests) | ||||
|             input_lengths.extend(batch.input_lengths) | ||||
|             offsets.extend(batch.offsets) | ||||
|             token_offsets.extend(batch.token_offsets) | ||||
|             all_input_ids.extend(batch.all_input_ids) | ||||
|             all_input_ids_tensor.extend(batch.all_input_ids_tensor) | ||||
|             next_token_choosers.extend(batch.next_token_choosers) | ||||
| @@ -182,6 +194,8 @@ class FlashCausalLMBatch(Batch): | ||||
|             max_seqlen=max_seqlen, | ||||
|             past_key_values=past_key_values, | ||||
|             input_lengths=input_lengths, | ||||
|             offsets=offsets, | ||||
|             token_offsets=token_offsets, | ||||
|             all_input_ids=all_input_ids, | ||||
|             all_input_ids_tensor=all_input_ids_tensor, | ||||
|             next_token_choosers=next_token_choosers, | ||||
| @@ -279,6 +293,8 @@ class FlashCausalLM(Model): | ||||
|         next_batch_max_seqlen = 0 | ||||
|         next_batch_past_key_values = [] | ||||
|         next_batch_input_lengths = [] | ||||
|         next_batch_offsets = [] | ||||
|         next_batch_token_offsets = [] | ||||
|         next_batch_all_input_ids = [] | ||||
|         next_batch_all_input_ids_tensor = [] | ||||
|  | ||||
| @@ -292,6 +308,8 @@ class FlashCausalLM(Model): | ||||
|         iterator = zip( | ||||
|             batch.requests, | ||||
|             batch.input_lengths, | ||||
|             batch.offsets, | ||||
|             batch.token_offsets, | ||||
|             batch.next_token_choosers, | ||||
|             batch.stopping_criterias, | ||||
|             batch.all_input_ids, | ||||
| @@ -302,6 +320,8 @@ class FlashCausalLM(Model): | ||||
|         for i, ( | ||||
|             request, | ||||
|             input_length, | ||||
|             offset, | ||||
|             token_offset, | ||||
|             next_token_chooser, | ||||
|             stopping_criteria, | ||||
|             all_input_ids, | ||||
| @@ -334,8 +354,10 @@ class FlashCausalLM(Model): | ||||
|  | ||||
|             # Generated token | ||||
|             next_token_logprob = logprobs[-1, next_token_id_item] | ||||
|             next_token_text = self.decode_token( | ||||
|                 next_token_id_item, | ||||
|             next_token_text, offset, token_offset = self.decode_token( | ||||
|                 all_input_ids, | ||||
|                 offset, | ||||
|                 token_offset, | ||||
|             ) | ||||
|  | ||||
|             # Evaluate stopping criteria | ||||
| @@ -376,6 +398,8 @@ class FlashCausalLM(Model): | ||||
|                     next_batch_cu_seqlens[-1] + new_input_length | ||||
|                 ) | ||||
|                 next_batch_input_lengths.append(new_input_length) | ||||
|                 next_batch_offsets.append(offset) | ||||
|                 next_batch_token_offsets.append(token_offset) | ||||
|                 next_batch_all_input_ids.append(all_input_ids) | ||||
|                 next_batch_all_input_ids_tensor.append(all_input_ids_tensor) | ||||
|                 next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length) | ||||
| @@ -452,6 +476,8 @@ class FlashCausalLM(Model): | ||||
|             max_seqlen=next_batch_max_seqlen, | ||||
|             past_key_values=next_batch_past_key_values, | ||||
|             input_lengths=next_batch_input_lengths, | ||||
|             offsets=next_batch_offsets, | ||||
|             token_offsets=next_batch_token_offsets, | ||||
|             all_input_ids=next_batch_all_input_ids, | ||||
|             all_input_ids_tensor=next_batch_all_input_ids_tensor, | ||||
|             next_token_choosers=next_batch_next_token_choosers, | ||||
|   | ||||
							
								
								
									
										303
									
								
								server/text_generation_server/models/flash_llama.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										303
									
								
								server/text_generation_server/models/flash_llama.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,303 @@ | ||||
| import torch | ||||
| import torch.distributed | ||||
|  | ||||
| from accelerate import init_empty_weights | ||||
| from opentelemetry import trace | ||||
| from pathlib import Path | ||||
| from safetensors import safe_open | ||||
| from transformers import AutoConfig | ||||
| from transformers.models.llama import LlamaTokenizer | ||||
| from typing import Optional, List | ||||
|  | ||||
| from text_generation_server.models import FlashCausalLM | ||||
| from text_generation_server.models.custom_modeling.flash_llama_modeling import ( | ||||
|     FlashLlamaForCausalLM, | ||||
|     TensorParallelEmbedding, | ||||
|     TensorParallelRowLinear, | ||||
|     TensorParallelColumnLinear, | ||||
| ) | ||||
| from text_generation_server.utils import ( | ||||
|     initialize_torch_distributed, | ||||
|     weight_files, | ||||
|     download_weights, | ||||
|     weight_hub_files, | ||||
|     LocalEntryNotFoundError, | ||||
| ) | ||||
|  | ||||
| tracer = trace.get_tracer(__name__) | ||||
|  | ||||
|  | ||||
| class FlashLlama(FlashCausalLM): | ||||
|     def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False): | ||||
|         if torch.cuda.is_available(): | ||||
|             device = torch.device("cuda") | ||||
|             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 | ||||
|         else: | ||||
|             raise NotImplementedError("FlashLlama is only available on GPU") | ||||
|  | ||||
|         if quantize: | ||||
|             raise NotImplementedError("FlashLlama does not support quantization") | ||||
|  | ||||
|         tokenizer = LlamaTokenizer.from_pretrained( | ||||
|             model_id, | ||||
|             revision=revision, | ||||
|             padding_side="left", | ||||
|             truncation_side="left", | ||||
|         ) | ||||
|  | ||||
|         config = AutoConfig.from_pretrained( | ||||
|             model_id, | ||||
|             revision=revision, | ||||
|         ) | ||||
|  | ||||
|         # We do not use from_pretrained as we modified the model internal module layout | ||||
|         try: | ||||
|             filenames = weight_files(model_id, revision, ".bin") | ||||
|         # Local files not found | ||||
|         except LocalEntryNotFoundError: | ||||
|             hub_files = weight_hub_files(model_id, revision, ".bin") | ||||
|             filenames = download_weights(hub_files, model_id, revision) | ||||
|  | ||||
|         with init_empty_weights(): | ||||
|             model = FlashLlamaForCausalLM(config) | ||||
|  | ||||
|         self.load_weights(model, filenames, device, dtype) | ||||
|         self.model = model.eval() | ||||
|  | ||||
|         super(FlashCausalLM, self).__init__( | ||||
|             tokenizer=tokenizer, | ||||
|             device=device, | ||||
|         ) | ||||
|  | ||||
|     @staticmethod | ||||
|     def load_weights( | ||||
|         model, | ||||
|         filenames: List[Path], | ||||
|         device: torch.device, | ||||
|         dtype: torch.dtype, | ||||
|     ): | ||||
|         for filename in filenames: | ||||
|             state_dict = torch.load(filename, map_location="cpu") | ||||
|             for key, value in state_dict.items(): | ||||
|                 value = value.to(device).to(dtype) | ||||
|  | ||||
|                 layer_name = ".".join(key.split(".")[:4]) | ||||
|  | ||||
|                 # Fused qkv | ||||
|                 if "q_proj" in key or "k_proj" in key or "v_proj" in key: | ||||
|                     final_key = layer_name + ".query_key_value.weight" | ||||
|  | ||||
|                 # Fused gate and up projs | ||||
|                 elif "gate_proj" in key or "up_proj" in key: | ||||
|                     final_key = layer_name + ".gate_up_proj.weight" | ||||
|                 else: | ||||
|                     final_key = key | ||||
|  | ||||
|                 module_name, param_name = final_key.rsplit(".", 1) | ||||
|                 module = model.get_submodule(module_name) | ||||
|  | ||||
|                 try: | ||||
|                     current_parameter_tensor = module._parameters[param_name] | ||||
|                 except KeyError: | ||||
|                     current_parameter_tensor = None | ||||
|  | ||||
|                 if current_parameter_tensor is not None: | ||||
|                     if current_parameter_tensor.device == torch.device("meta"): | ||||
|                         # Init qkv | ||||
|                         if "query_key_value" in final_key: | ||||
|                             module._parameters[param_name] = value.new_empty( | ||||
|                                 (value.shape[0] * 3, value.shape[1]) | ||||
|                             ) | ||||
|                         # Init gate and up proj | ||||
|                         elif "gate_up_proj" in final_key: | ||||
|                             module._parameters[param_name] = value.new_empty( | ||||
|                                 (value.shape[0] * 2, value.shape[1]) | ||||
|                             ) | ||||
|  | ||||
|                     # Copy to correct slice | ||||
|                     if "q_proj" in key: | ||||
|                         module._parameters[param_name][: value.shape[0]] = value | ||||
|                     elif "k_proj" in key: | ||||
|                         module._parameters[param_name][ | ||||
|                             value.shape[0] : value.shape[0] * 2 | ||||
|                         ] = value | ||||
|                     elif "v_proj" in key: | ||||
|                         module._parameters[param_name][value.shape[0] * 2 :] = value | ||||
|                     elif "gate_proj" in key: | ||||
|                         module._parameters[param_name][: value.shape[0]] = value | ||||
|                     elif "up_proj" in key: | ||||
|                         module._parameters[param_name][value.shape[0] :] = value | ||||
|                     else: | ||||
|                         if current_parameter_tensor.shape != value.shape: | ||||
|                             raise ValueError( | ||||
|                                 f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}" | ||||
|                             ) | ||||
|                         module._parameters[param_name] = value | ||||
|                 else: | ||||
|                     module._buffers[param_name] = value | ||||
|  | ||||
|                 del value | ||||
|  | ||||
|         torch.cuda.empty_cache() | ||||
|         model.post_load_weights() | ||||
|  | ||||
|  | ||||
| class FlashLlamaSharded(FlashLlama): | ||||
|     def __init__( | ||||
|         self, model_id: str, revision: Optional[str] = None, quantize: bool = False | ||||
|     ): | ||||
|         self.process_group, self.rank, self.world_size = initialize_torch_distributed() | ||||
|         self.master = self.rank == 0 | ||||
|         if torch.cuda.is_available(): | ||||
|             device = torch.device(f"cuda:{self.rank}") | ||||
|             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 | ||||
|         else: | ||||
|             raise NotImplementedError("FlashLlama is only available on GPU") | ||||
|  | ||||
|         if quantize: | ||||
|             raise NotImplementedError("FlashLlama does not support quantization") | ||||
|  | ||||
|         tokenizer = LlamaTokenizer.from_pretrained( | ||||
|             model_id, | ||||
|             revision=revision, | ||||
|             padding_side="left", | ||||
|             truncation_side="left", | ||||
|         ) | ||||
|  | ||||
|         config = AutoConfig.from_pretrained( | ||||
|             model_id, | ||||
|             revision=revision, | ||||
|         ) | ||||
|  | ||||
|         torch.distributed.barrier(group=self.process_group) | ||||
|         filenames = weight_files(model_id, revision=revision, extension=".safetensors") | ||||
|  | ||||
|         with init_empty_weights(): | ||||
|             model = FlashLlamaForCausalLM(config, process_group=self.process_group) | ||||
|  | ||||
|         torch.distributed.barrier(group=self.process_group) | ||||
|         self.load_weights( | ||||
|             model, | ||||
|             filenames, | ||||
|             quantize=quantize, | ||||
|             device=device, | ||||
|             dtype=dtype, | ||||
|             rank=self.rank, | ||||
|             world_size=self.world_size, | ||||
|         ) | ||||
|         self.model = model.eval() | ||||
|         torch.distributed.barrier(group=self.process_group) | ||||
|         super(FlashCausalLM, self).__init__( | ||||
|             tokenizer=tokenizer, | ||||
|             device=device, | ||||
|         ) | ||||
|  | ||||
|     @staticmethod | ||||
|     def load_weights( | ||||
|         model, | ||||
|         filenames: List[str], | ||||
|         quantize: bool, | ||||
|         device: torch.device, | ||||
|         dtype: torch.dtype, | ||||
|         rank: int, | ||||
|         world_size: int, | ||||
|     ): | ||||
|         for file in filenames: | ||||
|             with safe_open( | ||||
|                 file, framework="pt", device=str(device) if not quantize else "cpu" | ||||
|             ) as f: | ||||
|                 for name in f.keys(): | ||||
|                     slice_ = f.get_slice(name) | ||||
|  | ||||
|                     layer_name = ".".join(name.split(".")[:4]) | ||||
|  | ||||
|                     # Fused qkv | ||||
|                     if "q_proj" in name or "k_proj" in name or "v_proj" in name: | ||||
|                         final_name = layer_name + ".query_key_value.weight" | ||||
|  | ||||
|                     # Fused gate and up projs | ||||
|                     elif "gate_proj" in name or "up_proj" in name: | ||||
|                         final_name = layer_name + ".gate_up_proj.weight" | ||||
|                     else: | ||||
|                         final_name = name | ||||
|  | ||||
|                     module_name, param_name = final_name.rsplit(".", 1) | ||||
|                     module = model.get_submodule(module_name) | ||||
|  | ||||
|                     if isinstance(module, TensorParallelColumnLinear): | ||||
|                         size = slice_.get_shape()[0] | ||||
|                         block_size = size // world_size | ||||
|                         start = rank * block_size | ||||
|                         stop = (rank + 1) * block_size | ||||
|                         tensor = slice_[start:stop] | ||||
|                     elif isinstance(module, TensorParallelRowLinear): | ||||
|                         size = slice_.get_shape()[1] | ||||
|                         block_size = size // world_size | ||||
|                         start = rank * block_size | ||||
|                         stop = (rank + 1) * block_size | ||||
|                         tensor = slice_[:, start:stop] | ||||
|                     elif isinstance(module, TensorParallelEmbedding): | ||||
|                         size = slice_.get_shape()[0] | ||||
|                         block_size = size // world_size | ||||
|                         start = rank * block_size | ||||
|                         stop = (rank + 1) * block_size | ||||
|                         tensor = slice_[start:stop] | ||||
|                     elif name == "lm_head.weight" and model.model.tp_embeddings: | ||||
|                         size = slice_.get_shape()[0] | ||||
|                         block_size = size // world_size | ||||
|                         start = rank * block_size | ||||
|                         stop = (rank + 1) * block_size | ||||
|                         tensor = slice_[start:stop] | ||||
|                     else: | ||||
|                         try: | ||||
|                             tensor = slice_[:] | ||||
|                         except: | ||||
|                             tensor = f.get_tensor(name) | ||||
|  | ||||
|                     tensor = tensor.contiguous().to(dtype) | ||||
|  | ||||
|                     try: | ||||
|                         current_parameter_tensor = module._parameters[param_name] | ||||
|                     except KeyError: | ||||
|                         current_parameter_tensor = None | ||||
|  | ||||
|                     if current_parameter_tensor is not None: | ||||
|                         if current_parameter_tensor.device == torch.device("meta"): | ||||
|                             # Init qkv | ||||
|                             if "query_key_value" in final_name: | ||||
|                                 module._parameters[param_name] = tensor.new_empty( | ||||
|                                     (tensor.shape[0] * 3, tensor.shape[1]) | ||||
|                                 ) | ||||
|                             # Init gate and up proj | ||||
|                             elif "gate_up_proj" in final_name: | ||||
|                                 module._parameters[param_name] = tensor.new_empty( | ||||
|                                     (tensor.shape[0] * 2, tensor.shape[1]) | ||||
|                                 ) | ||||
|  | ||||
|                         # Init gate and up proj | ||||
|                         if "q_proj" in name: | ||||
|                             module._parameters[param_name][: tensor.shape[0]] = tensor | ||||
|                         elif "k_proj" in name: | ||||
|                             module._parameters[param_name][ | ||||
|                                 tensor.shape[0] : tensor.shape[0] * 2 | ||||
|                             ] = tensor | ||||
|                         elif "v_proj" in name: | ||||
|                             module._parameters[param_name][ | ||||
|                                 tensor.shape[0] * 2 : | ||||
|                             ] = tensor | ||||
|                         elif "gate_proj" in name: | ||||
|                             module._parameters[param_name][: tensor.shape[0]] = tensor | ||||
|                         elif "up_proj" in name: | ||||
|                             module._parameters[param_name][tensor.shape[0] :] = tensor | ||||
|                         else: | ||||
|                             if current_parameter_tensor.shape != tensor.shape: | ||||
|                                 raise ValueError( | ||||
|                                     f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" | ||||
|                                 ) | ||||
|  | ||||
|                             module._parameters[param_name] = tensor | ||||
|  | ||||
|                     else: | ||||
|                         module._buffers[param_name] = tensor | ||||
|         torch.cuda.empty_cache() | ||||
|         model.post_load_weights() | ||||
| @@ -93,7 +93,8 @@ class GalacticaCausalLMBatch(CausalLMBatch): | ||||
|         inputs = [] | ||||
|         next_token_choosers = [] | ||||
|         stopping_criterias = [] | ||||
|         input_lengths = [] | ||||
|         offsets = [] | ||||
|         token_offsets = [] | ||||
|  | ||||
|         # Parse batch | ||||
|         max_truncation = 0 | ||||
| @@ -101,7 +102,8 @@ class GalacticaCausalLMBatch(CausalLMBatch): | ||||
|         for r in pb.requests: | ||||
|             # Add escape_custom_split_sequence to the CausalLMBatch logic | ||||
|             inputs.append(escape_custom_split_sequence(r.inputs)) | ||||
|             input_lengths.append(r.input_length) | ||||
|             offsets.append(None) | ||||
|             token_offsets.append(None) | ||||
|             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) | ||||
|             stopping_criteria = StoppingCriteria.from_pb( | ||||
|                 r.stopping_parameters, tokenizer | ||||
| @@ -146,6 +148,8 @@ class GalacticaCausalLMBatch(CausalLMBatch): | ||||
|             past_key_values=None, | ||||
|             all_input_ids=all_input_ids, | ||||
|             input_lengths=input_lengths, | ||||
|             offsets=offsets, | ||||
|             token_offsets=token_offsets, | ||||
|             next_token_choosers=next_token_choosers, | ||||
|             stopping_criterias=stopping_criterias, | ||||
|             size=pb.size, | ||||
|   | ||||
| @@ -15,15 +15,6 @@ class Model(ABC): | ||||
|         self.all_special_ids = set(tokenizer.all_special_ids) | ||||
|         self.device = device | ||||
|  | ||||
|         # see `decode_token` method | ||||
|         self.tokenizer.add_special_tokens( | ||||
|             {"additional_special_tokens": ["<decode-token>"]} | ||||
|         ) | ||||
|         self.special_decode_token_id = self.tokenizer.convert_tokens_to_ids( | ||||
|             "<decode-token>" | ||||
|         ) | ||||
|         self.special_decode_token_length = len("<decode-token>") | ||||
|  | ||||
|     @property | ||||
|     @abstractmethod | ||||
|     def batch_type(self) -> Type[B]: | ||||
| @@ -33,11 +24,38 @@ class Model(ABC): | ||||
|     def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def decode_token(self, token_id: int) -> str: | ||||
|     def decode_token( | ||||
|         self, | ||||
|         all_input_ids: List[int], | ||||
|         offset: Optional[int] = None, | ||||
|         token_offset: Optional[int] = None, | ||||
|     ) -> Tuple[str, Optional[int], Optional[int]]: | ||||
|         """Hack to hopefully support generate_stream for the maximum number of tokenizers""" | ||||
|         # append token to special decode token and decode both | ||||
|         result = self.tokenizer.decode( | ||||
|             [self.special_decode_token_id, token_id], skip_special_tokens=False | ||||
|         if all_input_ids[-1] in self.all_special_ids: | ||||
|             return ( | ||||
|                 self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False), | ||||
|                 None, | ||||
|                 None, | ||||
|             ) | ||||
|  | ||||
|         if token_offset is None: | ||||
|             token_offset = len(all_input_ids) - 3 | ||||
|  | ||||
|         # Decode token_offset token minus last one and token_offset tokens | ||||
|         results = self.tokenizer.batch_decode( | ||||
|             [all_input_ids[token_offset:-1], all_input_ids[token_offset:]], | ||||
|             skip_special_tokens=False, | ||||
|         ) | ||||
|         # slice to remove special decode token | ||||
|         return result[self.special_decode_token_length :] | ||||
|  | ||||
|         # default offset is only the last token | ||||
|         if offset is None: | ||||
|             offset = len(results[0]) | ||||
|  | ||||
|         # get text | ||||
|         text = results[1][offset:] | ||||
|  | ||||
|         # if text is utf-8 | ||||
|         if text and text[-1] != "<EFBFBD>": | ||||
|             return text, None, None | ||||
|         else: | ||||
|             return "", offset, token_offset | ||||
|   | ||||
| @@ -38,6 +38,8 @@ class Seq2SeqLMBatch(Batch): | ||||
|     # Lengths of all generations present in the batch | ||||
|     input_lengths: List[int] | ||||
|     decoder_input_lengths: List[int] | ||||
|     offsets: List[Optional[int]] | ||||
|     token_offsets: List[Optional[int]] | ||||
|  | ||||
|     # Generation helpers | ||||
|     next_token_choosers: List[NextTokenChooser] | ||||
| @@ -71,6 +73,8 @@ class Seq2SeqLMBatch(Batch): | ||||
|  | ||||
|         decoder_input_ids = [] | ||||
|         decoder_input_lengths = [] | ||||
|         offsets = [] | ||||
|         token_offsets = [] | ||||
|  | ||||
|         # Parse batch | ||||
|         max_truncation = 0 | ||||
| @@ -80,6 +84,8 @@ class Seq2SeqLMBatch(Batch): | ||||
|             # Decoder sequence only contains the bos_token | ||||
|             decoder_input_ids.append(tokenizer.bos_token_id) | ||||
|             decoder_input_lengths.append(1) | ||||
|             offsets.append(None) | ||||
|             token_offsets.append(None) | ||||
|             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) | ||||
|             stopping_criteria = StoppingCriteria.from_pb( | ||||
|                 r.stopping_parameters, tokenizer | ||||
| @@ -117,6 +123,8 @@ class Seq2SeqLMBatch(Batch): | ||||
|             past_key_values=None, | ||||
|             input_lengths=input_lengths.tolist(), | ||||
|             decoder_input_lengths=decoder_input_lengths, | ||||
|             offsets=offsets, | ||||
|             token_offsets=token_offsets, | ||||
|             next_token_choosers=next_token_choosers, | ||||
|             stopping_criterias=stopping_criterias, | ||||
|             size=len(pb.requests), | ||||
| @@ -147,6 +155,8 @@ class Seq2SeqLMBatch(Batch): | ||||
|         requests = [] | ||||
|         input_lengths = [] | ||||
|         decoder_input_lengths = [] | ||||
|         offsets = [] | ||||
|         token_offsets = [] | ||||
|         next_token_choosers = [] | ||||
|         stopping_criterias = [] | ||||
|  | ||||
| @@ -166,6 +176,8 @@ class Seq2SeqLMBatch(Batch): | ||||
|             requests.extend(batch.requests) | ||||
|             input_lengths.extend(batch.input_lengths) | ||||
|             decoder_input_lengths.extend(batch.decoder_input_lengths) | ||||
|             offsets.extend(batch.offsets) | ||||
|             token_offsets.extend(batch.token_offsets) | ||||
|             next_token_choosers.extend(batch.next_token_choosers) | ||||
|             stopping_criterias.extend(batch.stopping_criterias) | ||||
|  | ||||
| @@ -303,6 +315,8 @@ class Seq2SeqLMBatch(Batch): | ||||
|             past_key_values=past_key_values, | ||||
|             input_lengths=input_lengths, | ||||
|             decoder_input_lengths=decoder_input_lengths, | ||||
|             offsets=offsets, | ||||
|             token_offsets=token_offsets, | ||||
|             next_token_choosers=next_token_choosers, | ||||
|             stopping_criterias=stopping_criterias, | ||||
|             size=total_batch_size, | ||||
| @@ -335,7 +349,7 @@ class Seq2SeqLM(Model): | ||||
|             load_in_8bit=quantize, | ||||
|         ).eval() | ||||
|         tokenizer = AutoTokenizer.from_pretrained( | ||||
|             model_id, revision=revision, padding_side="left" | ||||
|             model_id, revision=revision, padding_side="left", truncation_side="left" | ||||
|         ) | ||||
|         tokenizer.bos_token_id = self.model.config.decoder_start_token_id | ||||
|  | ||||
| @@ -422,6 +436,8 @@ class Seq2SeqLM(Model): | ||||
|  | ||||
|         # New values for next forward | ||||
|         next_batch_input_lengths = [] | ||||
|         next_batch_offsets = [] | ||||
|         next_batch_token_offsets = [] | ||||
|         next_batch_decoder_input_ids = [] | ||||
|         next_batch_decoder_input_lengths = [] | ||||
|  | ||||
| @@ -437,6 +453,8 @@ class Seq2SeqLM(Model): | ||||
|         iterator = zip( | ||||
|             batch.requests, | ||||
|             batch.input_lengths, | ||||
|             batch.offsets, | ||||
|             batch.token_offsets, | ||||
|             batch.decoder_input_lengths, | ||||
|             logits, | ||||
|             batch.next_token_choosers, | ||||
| @@ -448,6 +466,8 @@ class Seq2SeqLM(Model): | ||||
|         for i, ( | ||||
|             request, | ||||
|             input_length, | ||||
|             offset, | ||||
|             token_offset, | ||||
|             decoder_input_length, | ||||
|             logits, | ||||
|             next_token_chooser, | ||||
| @@ -466,8 +486,8 @@ class Seq2SeqLM(Model): | ||||
|             # Generated token | ||||
|             next_token_logprob = logprobs[-1, next_token_id] | ||||
|             next_token_id_squeezed = next_token_id.squeeze() | ||||
|             next_token_text = self.decode_token( | ||||
|                 next_token_id_squeezed, | ||||
|             next_token_text, offset, token_offset = self.decode_token( | ||||
|                 decoder_input_ids, offset, token_offset | ||||
|             ) | ||||
|  | ||||
|             # Evaluate stopping criteria | ||||
| @@ -495,6 +515,8 @@ class Seq2SeqLM(Model): | ||||
|                 next_batch_size += 1 | ||||
|                 next_batch_input_lengths.append(input_length) | ||||
|                 next_batch_decoder_input_lengths.append(new_decoder_input_length) | ||||
|                 next_batch_offsets.append(offset) | ||||
|                 next_batch_token_offsets.append(token_offset) | ||||
|                 next_batch_max_input_length = max( | ||||
|                     next_batch_max_input_length, input_length | ||||
|                 ) | ||||
| @@ -580,6 +602,8 @@ class Seq2SeqLM(Model): | ||||
|             past_key_values=next_batch_past_key_values, | ||||
|             input_lengths=next_batch_input_lengths, | ||||
|             decoder_input_lengths=next_batch_decoder_input_lengths, | ||||
|             offsets=next_batch_offsets, | ||||
|             token_offsets=next_batch_token_offsets, | ||||
|             next_token_choosers=next_batch_next_token_choosers, | ||||
|             stopping_criterias=next_batch_stopping_criterias, | ||||
|             size=next_batch_size, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 OlivierDehaene
					OlivierDehaene