mirror of
https://github.com/openai/gpt-oss.git
synced 2025-08-06 00:55:46 +03:00
Initial commit
Co-authored-by: Zhuohan Li <zhuohan@openai.com> Co-authored-by: Maratyszcza <marat@openai.com> Co-authored-by: Volodymyr Kyrylov <vol@wilab.org.ua>
This commit is contained in:
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: 🐛 Model Issues
|
||||
url: https://huggingface.co/openai/gpt-oss-120b/discussions
|
||||
about: For general questions about the models, please use the Community feature on Hugging Face.
|
||||
- name: 💡 General Feedback
|
||||
url: https://openai.com/open-models
|
||||
about: Suggest new features on our feature request page.
|
||||
7
.gitignore
vendored
Normal file
7
.gitignore
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
build
|
||||
_skbuild
|
||||
tmp*
|
||||
__pycache__
|
||||
*.egg*
|
||||
node_modules/
|
||||
*.log
|
||||
23
CMakeLists.txt
Normal file
23
CMakeLists.txt
Normal file
@@ -0,0 +1,23 @@
|
||||
cmake_minimum_required(VERSION 3.26)
|
||||
project(gpt_oss LANGUAGES C CXX)
|
||||
|
||||
# If not defined externally, auto-detect
|
||||
if(NOT DEFINED GPTOSS_BUILD_METAL)
|
||||
if(APPLE AND CMAKE_SYSTEM_PROCESSOR MATCHES "arm64")
|
||||
message(STATUS "Apple Silicon detected → enabling GPTOSS_BUILD_METAL")
|
||||
set(GPTOSS_BUILD_METAL ON)
|
||||
else()
|
||||
message(STATUS "Non-Apple Silicon → disabling GPTOSS_BUILD_METAL")
|
||||
set(GPTOSS_BUILD_METAL OFF)
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "GPTOSS_BUILD_METAL manually set to: ${GPTOSS_BUILD_METAL}")
|
||||
endif()
|
||||
|
||||
# Now declare it as a cache variable (respects user-provided value)
|
||||
set(GPTOSS_BUILD_METAL "${GPTOSS_BUILD_METAL}" CACHE BOOL "Enable Metal backend")
|
||||
|
||||
if(GPTOSS_BUILD_METAL)
|
||||
enable_language(OBJC)
|
||||
add_subdirectory(gpt_oss/metal)
|
||||
endif()
|
||||
201
LICENSE
Normal file
201
LICENSE
Normal file
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2025 OpenAI
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
460
README.md
Normal file
460
README.md
Normal file
@@ -0,0 +1,460 @@
|
||||
<center>
|
||||
<img alt="gpt-oss-120" src="./docs/gpt-oss.svg">
|
||||
<p align="center"><br>
|
||||
<a href="https://gpt-oss.com" target="_blank">Try gpt-oss</a> | <a href="https://cookbook.openai.com/topic/gpt-oss">Guides</a> | <a href="https://openai.com/index/gpt-oss-model-card">Model card</a>
|
||||
<br><a href="https://openai.com/index/introducing-gpt-oss/">Learn more about OpenAI's open models</a><br>
|
||||
Download <a href="https://huggingface.co/openai/gpt-oss-120b">gpt-oss-120b</a> and <a href="https://huggingface.co/openai/gpt-oss-20b">gpt-oss-20b</a> on Hugging Face</a>
|
||||
</p>
|
||||
<br>
|
||||
</center>
|
||||
|
||||
Welcome to the gpt-oss series, [OpenAI's open-weight models](https://openai.com/open-models/) designed for powerful reasoning, agentic tasks, and versatile developer use cases.
|
||||
|
||||
We're releasing two flavors of the open models:
|
||||
|
||||
- `gpt-oss-120b` — for production, general purpose, high reasoning use cases that fits into a single H100 GPU (117B parameters with 5.1B active parameters)
|
||||
- `gpt-oss-20b` — for lower latency, and local or specialized use cases (21B parameters with 3.6B active parameters)
|
||||
|
||||
Both models were trained on our [harmony response format][harmony] and should only be used with the harmony format as it will not work correctly otherwise.
|
||||
|
||||
### Highlights
|
||||
|
||||
- **Permissive Apache 2.0 license:** Build freely without copyleft restrictions or patent risk—ideal for experimentation, customization, and commercial deployment.
|
||||
- **Configurable reasoning effort:** Easily adjust the reasoning effort (low, medium, high) based on your specific use case and latency needs.
|
||||
- **Full chain-of-thought:** Gain complete access to the model's reasoning process, facilitating easier debugging and increased trust in outputs. It's not intended to be shown to end users.
|
||||
- **Fine-tunable:** Fully customize models to your specific use case through parameter fine-tuning.
|
||||
- **Agentic capabilities:** Use the models' native capabilities for function calling, [web browsing](#browser), [Python code execution](#python), and Structured Outputs.
|
||||
- **Native MXFP4 quantization:** The models are trained with native MXFP4 precision for the MoE layer, making `gpt-oss-120b` run on a single H100 GPU and the `gpt-oss-20b` model run within 16GB of memory.
|
||||
|
||||
### Inference examples
|
||||
|
||||
#### Transformers
|
||||
|
||||
You can use `gpt-oss-120b` and `gpt-oss-20b` with Transformers. If you use Transformers's chat template it will automatically apply the [harmony response format][harmony]. If you use `model.generate` directly, you need to apply the harmony format manually using the chat template or use our [`openai-harmony`][harmony] package.
|
||||
|
||||
```python
|
||||
from transformers import pipeline
|
||||
import torch
|
||||
|
||||
model_id = "openai/gpt-oss-120b"
|
||||
|
||||
pipe = pipeline(
|
||||
"text-generation",
|
||||
model=model_id,
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Explain quantum mechanics clearly and concisely."},
|
||||
]
|
||||
|
||||
outputs = pipe(
|
||||
messages,
|
||||
max_new_tokens=256,
|
||||
)
|
||||
print(outputs[0]["generated_text"][-1])
|
||||
```
|
||||
|
||||
[Learn more about how to use gpt-oss with Transformers.](https://cookbook.openai.com/articles/gpt-oss/run-transformers)
|
||||
|
||||
#### vLLM
|
||||
|
||||
vLLM recommends using [`uv`](https://docs.astral.sh/uv/) for Python dependency management. You can use vLLM to spin up an OpenAI-compatible webserver. The following command will automatically download the model and start the server.
|
||||
|
||||
```bash
|
||||
uv run --with vllm vllm serve openai/gpt-oss-20b
|
||||
```
|
||||
|
||||
[Learn more about how to use gpt-oss with vLLM.](https://cookbook.openai.com/articles/gpt-oss/run-vllm)
|
||||
|
||||
#### Pytorch / Triton / Metal
|
||||
|
||||
These implementations are largely reference implementations for educational purposes and are not expected to be run in production.
|
||||
|
||||
[Learn more below.](#reference-pytorch-implementation)
|
||||
|
||||
#### Ollama
|
||||
|
||||
If you are trying to run `gpt-oss` on consumer hardware, you can use Ollama by running the following commands after [installing Ollama](https://ollama.com/download).
|
||||
|
||||
```bash
|
||||
# gpt-oss-20b
|
||||
ollama pull gpt-oss:20b
|
||||
ollama run gpt-oss:20b
|
||||
|
||||
# gpt-oss-120b
|
||||
ollama pull gpt-oss:120b
|
||||
ollama run gpt-oss:120b
|
||||
```
|
||||
|
||||
[Learn more about how to use gpt-oss with Ollama.](https://cookbook.openai.com/articles/gpt-oss/run-locally-ollama)
|
||||
|
||||
#### LM Studio
|
||||
|
||||
If you are using [LM Studio](https://lmstudio.ai/) you can use the following commands to download.
|
||||
|
||||
```bash
|
||||
# gpt-oss-20b
|
||||
lms get openai/gpt-oss-20b
|
||||
# gpt-oss-120b
|
||||
lms get openai/gpt-oss-120b
|
||||
```
|
||||
|
||||
Check out our [awesome list](./awesome-gpt-oss.md) for a broader collection of gpt-oss resources and inference partners.
|
||||
|
||||
## About this repository
|
||||
|
||||
This repository provides a collection of reference implementations:
|
||||
|
||||
- **Inference:**
|
||||
- [`torch`](#reference-pytorch-implementation) — a non-optimized [Pytorch](https://pytorch.org/) implementation for educational purposes only. Requires at least 4x H100s because it's not optimized
|
||||
- [`triton`](#reference-triton-implementation-single-gpu) — a more optimized implementation using [Pytorch](https://pytorch.org/) & [Triton](https://github.com/triton-lang/triton) incl. using CUDA graphs and basic caching
|
||||
- [`metal`](#reference-metal-implementation) — a Metal-specific implementation for running the models on Apple Silicon hardware
|
||||
- **Tools:**
|
||||
- [`browser`](#browser) — a reference implementation of the browser tool the models got trained on
|
||||
- [`python`](#python) — a stateless reference implementation of the python tool the model got trained on
|
||||
- **Client examples:**
|
||||
- [`chat`](#terminal-chat) — a basic terminal chat application that uses the Pytorch or Triton implementations for inference along with the python and browser tools
|
||||
- [`responses_api`](#responses-api) — an example Responses API compatible server that implements the browser tool along with other Responses-compatible functionality
|
||||
|
||||
## Setup
|
||||
|
||||
### Requirements
|
||||
|
||||
- On macOS: Install the Xcode CLI tools --> `xcode-select --install`
|
||||
- On Linux: These reference implementations require CUDA
|
||||
- On Windows: These reference implementations have not been tested on Windows. Try using solutions like Ollama if you are trying to run the model locally.
|
||||
|
||||
### Installation
|
||||
|
||||
If you want to try any of the code you can install it directly from [PyPI](https://pypi.org/project/gpt-oss/)
|
||||
|
||||
```shell
|
||||
# if you just need the tools
|
||||
pip install gpt-oss
|
||||
# if you want to try the torch implementation
|
||||
pip install gpt-oss[torch]
|
||||
# if you want to try the triton implementation
|
||||
pip install gpt-oss[triton]
|
||||
```
|
||||
|
||||
If you want to modify the code or try the metal implementation set the project up locally:
|
||||
|
||||
```shell
|
||||
git clone https://github.com/openai/gpt-oss.git
|
||||
pip install -e .[metal]
|
||||
```
|
||||
|
||||
## Download the model
|
||||
|
||||
You can download the model weights from the [Hugging Face Hub](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4) directly from Hugging Face CLI:
|
||||
|
||||
```shell
|
||||
# gpt-oss-120b
|
||||
huggingface-cli download openai/gpt-oss-120b --include "original/*" --local-dir gpt-oss-120b/
|
||||
|
||||
# gpt-oss-20b
|
||||
huggingface-cli download openai/gpt-oss-20b --include "original/*" --local-dir gpt-oss-20b/
|
||||
```
|
||||
|
||||
## Reference PyTorch implementation
|
||||
|
||||
We include an inefficient reference PyTorch implementation in [gpt_oss/torch/model.py](gpt_oss/torch/model.py). This code uses basic PyTorch operators to show the exact model architecture, with a small addition of supporting tensor parallelism in MoE so that the larger model can run with this code (e.g., on 4xH100 or 2xH200). In this implementation, we upcast all weights to BF16 and run the model in BF16.
|
||||
|
||||
To run the reference implementation. Install dependencies:
|
||||
|
||||
```shell
|
||||
pip install -e .[torch]
|
||||
```
|
||||
|
||||
And then run:
|
||||
|
||||
```shell
|
||||
# On 4xH100:
|
||||
torchrun --nproc-per-node=4 -m gpt_oss.generate gpt-oss-120b/original/
|
||||
```
|
||||
|
||||
## Reference Triton implementation (single GPU)
|
||||
|
||||
We also include an optimized reference implementation that uses [an optimized triton MoE kernel](https://github.com/triton-lang/triton/tree/main/python/triton_kernels/triton_kernels) that supports MXFP4. It also has some optimization on the attention code to reduce the memory cost. To run this implementation, the nightly version of triton and torch will be installed. This version can be run on a single 80GB GPU for `gpt-oss-120b`.
|
||||
|
||||
To install the reference Triton implementation run
|
||||
|
||||
```shell
|
||||
pip install -e .[triton]
|
||||
```
|
||||
|
||||
And then run:
|
||||
|
||||
```shell
|
||||
# On 1xH100
|
||||
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
||||
python -m gpt_oss.generate --backend triton gpt-oss-120b/original/
|
||||
```
|
||||
|
||||
If you encounter `torch.OutOfMemoryError` make sure to turn on the expandable allocator to avoid crashes when loading weights from the checkpoint.
|
||||
|
||||
## Reference Metal implementation
|
||||
|
||||
Additionally we are providing a reference implementation for Metal to run on Apple Silicon. This implementation is not production ready but is accurate to the Pytorch implementation.
|
||||
|
||||
The implementation will get automatically compiled when running the `.[metal]` installation on an Apple Silicon device:
|
||||
|
||||
```shell
|
||||
pip install -e .[metal]
|
||||
```
|
||||
|
||||
To perform inference you'll need to first convert the SafeTensor weights from Hugging Face into the right format using:
|
||||
|
||||
```shell
|
||||
python gpt_oss/metal/scripts/create-local-model.py -s <model_dir> -d <output_file>
|
||||
```
|
||||
|
||||
To test it you can run:
|
||||
|
||||
```shell
|
||||
python gpt_oss/metal/examples/generate.py gpt-oss-20b/metal/model.bin -p "why did the chicken cross the road?"
|
||||
```
|
||||
|
||||
## Harmony format & tools
|
||||
|
||||
Along with the model, we are also releasing a new chat format library `harmony` to interact with the model. Check [harmony.md](harmony.md) for more info about harmony.
|
||||
|
||||
We also include two system tools for the model: browsing and python container. Check [gpt_oss/tools](gpt_oss/tools) for the tool implementation.
|
||||
|
||||
## Clients
|
||||
|
||||
### Terminal Chat
|
||||
|
||||
The terminal chat application is a basic example on how to use the harmony format together with the Pytorch, Triton, and vLLM implementations. It also exposes both the python and browser tool as optional tools that can be used.
|
||||
|
||||
```bash
|
||||
usage: python -m gpt_oss.chat [-h] [-r REASONING_EFFORT] [-a] [-b] [--show-browser-results] [-p] [--developer-message DEVELOPER_MESSAGE] [-c CONTEXT] [--raw] [--backend {triton,torch,vllm}] FILE
|
||||
|
||||
Chat example
|
||||
|
||||
positional arguments:
|
||||
FILE Path to the SafeTensors checkpoint
|
||||
|
||||
options:
|
||||
-h, --help show this help message and exit
|
||||
-r REASONING_EFFORT, --reasoning-effort REASONING_EFFORT
|
||||
Reasoning effort (default: low)
|
||||
-a, --apply-patch Make apply_patch tool available to the model (default: False)
|
||||
-b, --browser Use browser tool (default: False)
|
||||
--show-browser-results
|
||||
Show browser results (default: False)
|
||||
-p, --python Use python tool (default: False)
|
||||
--developer-message DEVELOPER_MESSAGE
|
||||
Developer message (default: )
|
||||
-c CONTEXT, --context CONTEXT
|
||||
Max context length (default: 8192)
|
||||
--raw Raw mode (does not render Harmony encoding) (default: False)
|
||||
--backend {triton,torch,vllm}
|
||||
Inference backend (default: triton)
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> The torch and triton implementation requires original checkpoint under `gpt-oss-120b/original/` and `gpt-oss-20b/original/` respectively. While vLLM uses the Hugging Face converted checkpoint under `gpt-oss-120b/` and `gpt-oss-20b/` root directory respectively.
|
||||
|
||||
### Responses API
|
||||
|
||||
We also include an example Responses API server. This server does not implement every feature and event of the Responses API but should be compatible with most of the basic use cases and serve as inspiration for anyone building their own server. Some of our inference partners are also offering their own Responses API.
|
||||
|
||||
You can start this server with the following inference backends:
|
||||
|
||||
- `triton` — uses the triton implementation
|
||||
- `metal` — uses the metal implementation on Apple Silicon only
|
||||
- `ollama` — uses the Ollama /api/generate API as a inference solution
|
||||
- `vllm` — uses your installed vllm version to perform inference
|
||||
|
||||
```bash
|
||||
usage: python -m gpt_oss.responses_api.serve [-h] [--checkpoint FILE] [--port PORT] [--inference-backend BACKEND]
|
||||
|
||||
Responses API server
|
||||
|
||||
options:
|
||||
-h, --help show this help message and exit
|
||||
--checkpoint FILE Path to the SafeTensors checkpoint
|
||||
--port PORT Port to run the server on
|
||||
--inference-backend BACKEND Inference backend to use
|
||||
```
|
||||
|
||||
### Codex
|
||||
|
||||
We support [codex](https://github.com/openai/codex) as a client for gpt-oss. To run the 20b version, set this to `~/.codex/config.toml`:
|
||||
|
||||
```
|
||||
disable_response_storage = true
|
||||
show_reasoning_content = true
|
||||
|
||||
[model_providers.local]
|
||||
name = "local"
|
||||
base_url = "http://localhost:11434/v1"
|
||||
|
||||
[profiles.oss]
|
||||
model = "gpt-oss:20b"
|
||||
model_provider = "local"
|
||||
```
|
||||
|
||||
This will work with any chat completions-API compatible server listening on port 11434, like ollama. Start the server and point codex to the oss model:
|
||||
|
||||
```
|
||||
ollama run gpt-oss:20b
|
||||
codex -p oss
|
||||
```
|
||||
|
||||
## Tools
|
||||
|
||||
### Browser
|
||||
|
||||
> [!WARNING]
|
||||
> This implementation is purely for educational purposes and should not be used in production. You should implement your own equivalent of the [`ExaBackend`](gpt_oss/tools/simple_browser/backend.py) class with your own browsing environment.
|
||||
|
||||
Both gpt-oss models were trained with the capability to browse using the `browser` tool that exposes the following three methods:
|
||||
|
||||
- `search` to search for key phrases
|
||||
- `open` to open a particular page
|
||||
- `find` to look for contents on a page
|
||||
|
||||
#### Usage
|
||||
|
||||
To enable the browser tool, you'll have to place the definition into the `system` message of your harmony formatted prompt. You can either use the `with_browser()` method if your tool implements the full interface or modify the definition using `with_tools()`. For example:
|
||||
|
||||
```python
|
||||
import datetime
|
||||
from gpt_oss.tools.simple_browser import SimpleBrowserTool
|
||||
from gpt_oss.tools.simple_browser.backend import ExaBackend
|
||||
from openai_harmony import SystemContent, Message, Conversation, Role, load_harmony_encoding, HarmonyEncodingName
|
||||
|
||||
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
|
||||
# Exa backend requires you to have set the EXA_API_KEY environment variable
|
||||
backend = ExaBackend(
|
||||
source="web",
|
||||
)
|
||||
browser_tool = SimpleBrowserTool(backend=backend)
|
||||
|
||||
# create a basic system prompt
|
||||
system_message_content = SystemContent.new().with_conversation_start_date(
|
||||
datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
)
|
||||
|
||||
# if you want to use the browser tool
|
||||
if use_browser_tool:
|
||||
# enables the tool
|
||||
system_message_content = system_message_content.with_tools(browser_tool.tool_config)
|
||||
# alternativel you could
|
||||
system_message_content = system_message_content.with_browser()
|
||||
|
||||
# construct the system message
|
||||
system_message = Message.from_role_and_content(Role.SYSTEM, system_message_content)
|
||||
|
||||
# create the overall prompt
|
||||
messages = [system_message, Message.from_role_and_content(Role.USER, "What's the weather in SF?")]
|
||||
conversation = Conversation.from_messages(messages)
|
||||
|
||||
# convert to tokens
|
||||
token_ids = encoding.render_conversation_for_completion(conversation, Role.ASSISTANT)
|
||||
|
||||
# perform inference
|
||||
# ...
|
||||
|
||||
# parse the output
|
||||
messages = messages = encoding.parse_messages_from_completion_tokens(output_tokens, Role.ASSISTANT)
|
||||
last_message = messages[-1]
|
||||
if last_message.recipient.startswith("browser"):
|
||||
# perform browser call
|
||||
response_messages = await browser_tool.process(last_message)
|
||||
|
||||
# extend the current messages and run inference again
|
||||
messages.extend(response_messages)
|
||||
```
|
||||
|
||||
#### Details
|
||||
|
||||
To control the context window size this tool use a scrollable window of text that the model can interact with. So it might fetch the first 50 lines of a page and then scroll to the next 20 lines after that. The model has also been trained to then use citations from this tool in its answers.
|
||||
|
||||
To improve performance the tool caches requests so that the model can revisit a different part of a page without having to reload the page. For that reason you should create a new browser instance for every request.
|
||||
|
||||
### Python
|
||||
|
||||
The model got trained on using a python tool to perform calculations and other actions as part of its chain-of-thought. During the training the model used a stateful tool which makes running tools between CoT loops easier. This reference implementation, however, uses a stateless mode. As a result the PythonTool defines its own tool description to override the definition in [`openai-harmony`][harmony].
|
||||
|
||||
> [!WARNING]
|
||||
> This implementation runs in a permissive Docker container which could be problematic in cases like prompt injections. It's serving as an example and you should consider implementing your own container restrictions in production.
|
||||
|
||||
#### Usage
|
||||
|
||||
To enable the browser tool, you'll have to place the definition into the `system` message of your harmony formatted prompt. You can either use the `with_python()` method if your tool implements the full interface or modify the definition using `with_tools()`. For example:
|
||||
|
||||
```python
|
||||
import datetime
|
||||
from gpt_oss.tools.python_docker.docker_tool import PythonTool
|
||||
from openai_harmony import SystemContent, Message, Conversation, Role, load_harmony_encoding, HarmonyEncodingName
|
||||
|
||||
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
|
||||
python_tool = PythonTool()
|
||||
|
||||
# create a basic system prompt
|
||||
system_message_content = SystemContent.new().with_conversation_start_date(
|
||||
datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
)
|
||||
|
||||
# if you want to use the python tool
|
||||
if use_python_tool:
|
||||
# enables the tool making sure that the prompt gets set with the stateless tool description
|
||||
system_message_content = system_message_content.with_tools(python_tool.tool_config)
|
||||
# alternatively you could use the following if your tool is not stateless
|
||||
system_message_content = system_message_content.with_python()
|
||||
|
||||
# construct the system message
|
||||
system_message = Message.from_role_and_content(Role.SYSTEM, system_message_content)
|
||||
|
||||
# create the overall prompt
|
||||
messages = [system_message, Message.from_role_and_content(Role.USER, "What's the squareroot of 9001?")]
|
||||
conversation = Conversation.from_messages(messages)
|
||||
|
||||
# convert to tokens
|
||||
token_ids = encoding.render_conversation_for_completion(conversation, Role.ASSISTANT)
|
||||
|
||||
# perform inference
|
||||
# ...
|
||||
|
||||
# parse the output
|
||||
messages = messages = encoding.parse_messages_from_completion_tokens(output_tokens, Role.ASSISTANT)
|
||||
last_message = messages[-1]
|
||||
if last_message.recipient == "python":
|
||||
# perform python call
|
||||
response_messages = await python_tool.process(last_message)
|
||||
|
||||
# extend the current messages and run inference again
|
||||
messages.extend(response_messages)
|
||||
```
|
||||
|
||||
### Apply Patch
|
||||
|
||||
`apply_patch` can be used to create, update or delete files locally.
|
||||
|
||||
## Other details
|
||||
|
||||
### Precision format
|
||||
|
||||
We released the models with native quantization support. Specifically, we use [MXFP4](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) for the linear projection weights in the MoE layer. We store the MoE tensor in two parts:
|
||||
|
||||
- `tensor.blocks` stores the actual fp4 values. We pack every two value in one `uint8` value.
|
||||
- `tensor.scales` stores the block scale. The block scaling is done among the last dimension for all MXFP4 tensors.
|
||||
|
||||
All other tensors will be in BF16. We also recommend use BF16 as the activation precision for the model.
|
||||
|
||||
### Recommended Sampling Parameters
|
||||
|
||||
We recommend sampling with `temperature=1.0` and `top_p=1.0`.
|
||||
|
||||
## Contributing
|
||||
|
||||
The reference implementations in this repository are meant as a starting point and inspiration. Outside of bug fixes we do not intend to accept new feature contributions. If you build implementations based on this code such as new tool implementations you are welcome to contribute them to the [`awesome-gpt-oss.md`](./awesome-gpt-oss.md) file.
|
||||
|
||||
[harmony]: https://github.com/openai/harmony
|
||||
70
awesome-gpt-oss.md
Normal file
70
awesome-gpt-oss.md
Normal file
@@ -0,0 +1,70 @@
|
||||

|
||||
|
||||
# Awesome gpt-oss
|
||||
|
||||
This is a list of guides and resources to help you get started with the gpt-oss models.
|
||||
|
||||
- [Inference](#inference)
|
||||
- [Local](#local)
|
||||
- [Server](#server)
|
||||
- [Cloud](#cloud)
|
||||
- [Examples / Tutorials](#examples--tutorials)
|
||||
- [Tools](#tools)
|
||||
|
||||
## Inference
|
||||
|
||||
### Local
|
||||
|
||||
- Ollama
|
||||
- [How to run gpt-oss locally with Ollama](https://cookbook.openai.com/articles/gpt-oss/run-locally-ollama)
|
||||
- [Ollama & gpt-oss launch blog](https://ollama.com/blog/gpt-oss)
|
||||
- [Check out the models Ollama](https://ollama.com/library/gpt-oss)
|
||||
- LM Studio
|
||||
- [LM Studio & gpt-oss launch blog](https://lmstudio.ai/blog/gpt-oss)
|
||||
- [Use gpt-oss-20b with LM Studio](https://lmstudio.ai/models/openai/gpt-oss-20b)
|
||||
- [Use gpt-oss-120b with LM Studio](https://lmstudio.ai/models/openai/gpt-oss-120b)
|
||||
- Hugging Face & Transformers
|
||||
- [How to run gpt-oss with Transformers](https://cookbook.openai.com/articles/gpt-oss/run-transformers)
|
||||
- [Hugging Face & gpt-oss launch blog](http://huggingface.co/blog/welcome-openai-gpt-oss)
|
||||
- [Collection of Hugging Face examples](https://github.com/huggingface/gpt-oss-recipes)
|
||||
- NVIDIA
|
||||
- [gpt-oss on RTX](https://blogs.nvidia.com/blog/rtx-ai-garage-openai-oss)
|
||||
|
||||
### Server
|
||||
|
||||
- vLLM
|
||||
- [How to run gpt-oss with vLLM](https://cookbook.openai.com/articles/gpt-oss/run-vllm)
|
||||
- NVIDIA
|
||||
- [Optimizing gpt-oss with NVIDIA TensorRT-LLM](https://cookbook.openai.com/articles/gpt-oss/run-nvidia)
|
||||
- [Deploying gpt-oss on TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/tech_blog/blog_9_Deploying_GPT_OSS_on_TRTLLM.md)
|
||||
|
||||
### Cloud
|
||||
|
||||
- Groq
|
||||
- [Groq & gpt-oss launch blog](http://groq.com/day-zero-support-for-openai-open-model)
|
||||
- [gpt-oss-120b model on the GroqCloud Playground](https://console.groq.com/playground?model=openai/gpt-oss-120b)
|
||||
- [gpt-oss-20b model on the GroqCloud Playground](https://console.groq.com/playground?model=openai/gpt-oss-20b)
|
||||
- [gpt-oss with built-in web search on GroqCloud](https://console.groq.com/docs/browser-search)
|
||||
- [gpt-oss with built-in code execution on GroqCloud](https://console.groq.com/docs/code-execution)
|
||||
- [Responses API on Groq](https://console.groq.com/docs/responses)
|
||||
- NVIDIA
|
||||
- [NVIDIA launch blog post](https://blogs.nvidia.com/blog/openai-gpt-oss/)
|
||||
- [NVIDIA & gpt-oss developer launch blog post](https://developer.nvidia.com/blog/delivering-1-5-m-tps-inference-on-nvidia-gb200-nvl72-nvidia-accelerates-openai-gpt-oss-models-from-cloud-to-edge/)
|
||||
- Use [gpt-oss-120b](https://build.nvidia.com/openai/gpt-oss-120b) and [gpt-oss-20b](https://build.nvidia.com/openai/gpt-oss-20b) on NVIDIA's Cloud
|
||||
- Cloudflare
|
||||
- [Cloudflare & gpt-oss launch blog post](http://blog.cloudflare.com/openai-gpt-oss-on-workers-ai)
|
||||
- [gpt-oss-120b on Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/models/gpt-oss-120b)
|
||||
- [gpt-oss-20b on Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/models/gpt-oss-20b)
|
||||
|
||||
## Examples & Tutorials
|
||||
|
||||
- [OpenAI harmony response format](https://cookbook.openai.com/articles/openai-harmony)
|
||||
|
||||
## Tools
|
||||
|
||||
- [Example `python` tool for gpt-oss](./gpt_oss/tools/python_docker/)
|
||||
- [Example `browser` tool for gpt-oss](./gpt_oss/tools/simple_browser/)
|
||||
|
||||
## Contributing
|
||||
|
||||
Feel free to open a PR to add your own guides and resources on how to run gpt-oss. We will try to review it and add it here.
|
||||
27
docs/gpt-oss-120b.svg
Normal file
27
docs/gpt-oss-120b.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 14 MiB |
27
docs/gpt-oss-20b.svg
Normal file
27
docs/gpt-oss-20b.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 14 MiB |
27
docs/gpt-oss.svg
Normal file
27
docs/gpt-oss.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 14 MiB |
90
examples/agents-sdk-js/index.ts
Normal file
90
examples/agents-sdk-js/index.ts
Normal file
@@ -0,0 +1,90 @@
|
||||
import { OpenAI } from "openai";
|
||||
import {
|
||||
Agent,
|
||||
run,
|
||||
setDefaultOpenAIClient,
|
||||
setOpenAIAPI,
|
||||
setTracingDisabled,
|
||||
tool,
|
||||
MCPServerStdio,
|
||||
} from "@openai/agents";
|
||||
import { z } from "zod";
|
||||
import path from "node:path";
|
||||
import process from "node:process";
|
||||
import { styleText } from "node:util";
|
||||
import { createInterface } from "node:readline/promises";
|
||||
|
||||
async function prompt(question: string) {
|
||||
const rl = createInterface({
|
||||
input: process.stdin,
|
||||
output: process.stdout,
|
||||
});
|
||||
const answer = await rl.question(question);
|
||||
rl.close();
|
||||
return answer;
|
||||
}
|
||||
|
||||
const openai = new OpenAI({
|
||||
apiKey: "local",
|
||||
baseURL: "http://localhost:11434/v1",
|
||||
});
|
||||
|
||||
const samplesDir = path.join(process.cwd());
|
||||
|
||||
const mcpServer = new MCPServerStdio({
|
||||
name: "Filesystem MCP Server, via npx",
|
||||
fullCommand: `npx -y @modelcontextprotocol/server-filesystem ${samplesDir}`,
|
||||
});
|
||||
|
||||
await mcpServer.connect();
|
||||
|
||||
setTracingDisabled(true);
|
||||
setDefaultOpenAIClient(openai);
|
||||
setOpenAIAPI("chat_completions");
|
||||
|
||||
const searchTool = tool({
|
||||
name: "get_current_weather",
|
||||
description: "Get the current weather in a given location",
|
||||
parameters: z.object({
|
||||
location: z.string(),
|
||||
}),
|
||||
execute: async ({ location }) => {
|
||||
return `The weather in ${location} is sunny.`;
|
||||
},
|
||||
});
|
||||
|
||||
const agent = new Agent({
|
||||
name: "My Agent",
|
||||
instructions: "You are a helpful assistant.",
|
||||
tools: [searchTool],
|
||||
model: "gpt-oss:20b-test",
|
||||
mcpServers: [mcpServer],
|
||||
});
|
||||
|
||||
const input = await prompt("> ");
|
||||
|
||||
const result = await run(agent, input, {
|
||||
stream: true,
|
||||
});
|
||||
|
||||
for await (const event of result) {
|
||||
if (event.type === "raw_model_stream_event" && event.data.type === "model") {
|
||||
if (event.data.event.choices[0].delta.content) {
|
||||
process.stdout.write(event.data.event.choices[0].delta.content);
|
||||
} else if (event.data.event.choices[0].delta.reasoning) {
|
||||
process.stdout.write(event.data.event.choices[0].delta.reasoning);
|
||||
}
|
||||
} else if (
|
||||
event.type === "run_item_stream_event" &&
|
||||
event.item.type === "tool_call_item" &&
|
||||
event.item.rawItem.type == "function_call"
|
||||
) {
|
||||
console.log(
|
||||
`\nCalling ${event.item.rawItem.name} with: ${event.item.rawItem.arguments}`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
console.log("\n");
|
||||
await result.completed;
|
||||
await mcpServer.close();
|
||||
1798
examples/agents-sdk-js/package-lock.json
generated
Normal file
1798
examples/agents-sdk-js/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
20
examples/agents-sdk-js/package.json
Normal file
20
examples/agents-sdk-js/package.json
Normal file
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"type": "module",
|
||||
"name": "agents-sdk",
|
||||
"version": "1.0.0",
|
||||
"main": "index.js",
|
||||
"scripts": {
|
||||
"start": "tsx index.ts",
|
||||
"test": "echo \"Error: no test specified\" && exit 1"
|
||||
},
|
||||
"keywords": [],
|
||||
"author": "",
|
||||
"license": "ISC",
|
||||
"description": "",
|
||||
"dependencies": {
|
||||
"@openai/agents": "^0.0.14",
|
||||
"tsx": "^4.20.3",
|
||||
"typescript": "^5.8.3",
|
||||
"zod": "^3.25.67"
|
||||
}
|
||||
}
|
||||
252
examples/streamlit/streamlit_chat.py
Normal file
252
examples/streamlit/streamlit_chat.py
Normal file
@@ -0,0 +1,252 @@
|
||||
import json
|
||||
|
||||
import requests
|
||||
import streamlit as st
|
||||
|
||||
DEFAULT_FUNCTION_PROPERTIES = """
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
""".strip()
|
||||
|
||||
# Session state for chat
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state.messages = []
|
||||
|
||||
st.title("💬 Chatbot")
|
||||
|
||||
if "model" not in st.session_state:
|
||||
if "model" in st.query_params:
|
||||
st.session_state.model = st.query_params["model"]
|
||||
else:
|
||||
st.session_state.model = "small"
|
||||
|
||||
options = ["large", "small"]
|
||||
selection = st.sidebar.segmented_control(
|
||||
"Model", options, selection_mode="single", default=st.session_state.model
|
||||
)
|
||||
# st.session_state.model = selection
|
||||
st.query_params.update({"model": selection})
|
||||
|
||||
instructions = st.sidebar.text_area(
|
||||
"Instructions",
|
||||
value="You are a helpful assistant that can answer questions and help with tasks.",
|
||||
)
|
||||
effort = st.sidebar.radio(
|
||||
"Reasoning effort",
|
||||
["low", "medium", "high"],
|
||||
index=1,
|
||||
)
|
||||
st.sidebar.divider()
|
||||
st.sidebar.subheader("Functions")
|
||||
use_functions = st.sidebar.toggle("Use functions", value=False)
|
||||
|
||||
if "show_browser" in st.query_params:
|
||||
st.sidebar.subheader("Built-in Tools")
|
||||
# Built-in Tools section
|
||||
use_browser_search = st.sidebar.toggle("Use browser search", value=False)
|
||||
else:
|
||||
use_browser_search = False
|
||||
|
||||
if use_functions:
|
||||
function_name = st.sidebar.text_input("Function name", value="get_weather")
|
||||
function_description = st.sidebar.text_area(
|
||||
"Function description", value="Get the weather for a given city"
|
||||
)
|
||||
function_parameters = st.sidebar.text_area(
|
||||
"Function parameters", value=DEFAULT_FUNCTION_PROPERTIES
|
||||
)
|
||||
else:
|
||||
function_name = None
|
||||
function_description = None
|
||||
function_parameters = None
|
||||
st.sidebar.divider()
|
||||
temperature = st.sidebar.slider(
|
||||
"Temperature", min_value=0.0, max_value=1.0, value=1.0, step=0.01
|
||||
)
|
||||
max_output_tokens = st.sidebar.slider(
|
||||
"Max output tokens", min_value=1000, max_value=20000, value=1024, step=100
|
||||
)
|
||||
st.sidebar.divider()
|
||||
debug_mode = st.sidebar.toggle("Debug mode", value=False)
|
||||
|
||||
if debug_mode:
|
||||
st.sidebar.divider()
|
||||
st.sidebar.code(json.dumps(st.session_state.messages, indent=2), "json")
|
||||
|
||||
render_input = True
|
||||
|
||||
URL = (
|
||||
"http://localhost:8081/v1/responses"
|
||||
if selection == options[1]
|
||||
else "http://localhost:8000/v1/responses"
|
||||
)
|
||||
|
||||
def trigger_fake_tool(container):
|
||||
function_output = st.session_state.get("function_output", "It's sunny!")
|
||||
last_call = st.session_state.messages[-1]
|
||||
if last_call.get("type") == "function_call":
|
||||
st.session_state.messages.append(
|
||||
{
|
||||
"type": "function_call_output",
|
||||
"call_id": last_call.get("call_id"),
|
||||
"output": function_output,
|
||||
}
|
||||
)
|
||||
run(container)
|
||||
|
||||
|
||||
def run(container):
|
||||
tools = []
|
||||
if use_functions:
|
||||
tools.append(
|
||||
{
|
||||
"type": "function",
|
||||
"name": function_name,
|
||||
"description": function_description,
|
||||
"parameters": json.loads(function_parameters),
|
||||
}
|
||||
)
|
||||
# Add browser_search tool if checkbox is checked
|
||||
if use_browser_search:
|
||||
tools.append({"type": "browser_search"})
|
||||
response = requests.post(
|
||||
URL,
|
||||
json={
|
||||
"input": st.session_state.messages,
|
||||
"stream": True,
|
||||
"instructions": instructions,
|
||||
"reasoning": {"effort": effort},
|
||||
"metadata": {"__debug": debug_mode},
|
||||
"tools": tools,
|
||||
"temperature": temperature,
|
||||
"max_output_tokens": max_output_tokens,
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
|
||||
text_delta = ""
|
||||
|
||||
current_output_index = 0
|
||||
for line in response.iter_lines(decode_unicode=True):
|
||||
if not line or not line.startswith("data:"):
|
||||
continue
|
||||
data_str = line[len("data:") :].strip()
|
||||
if not data_str:
|
||||
continue
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
event_type = data.get("type", "")
|
||||
output_index = data.get("output_index", 0)
|
||||
if event_type == "response.output_item.added":
|
||||
current_output_index = output_index
|
||||
output_type = data.get("item", {}).get("type", "message")
|
||||
if output_type == "message":
|
||||
output = container.chat_message("assistant")
|
||||
placeholder = output.empty()
|
||||
elif output_type == "reasoning":
|
||||
output = container.chat_message("reasoning", avatar="🤔")
|
||||
placeholder = output.empty()
|
||||
elif output_type == "web_search_call":
|
||||
output = container.chat_message("web_search_call", avatar="🌐")
|
||||
output.code(json.dumps(data.get("item", {}).get("action", {}), indent=4), language="json")
|
||||
placeholder = output.empty()
|
||||
text_delta = ""
|
||||
elif event_type == "response.reasoning_text.delta":
|
||||
output.avatar = "🤔"
|
||||
text_delta += data.get("delta", "")
|
||||
placeholder.markdown(text_delta)
|
||||
elif event_type == "response.output_text.delta":
|
||||
text_delta += data.get("delta", "")
|
||||
placeholder.markdown(text_delta)
|
||||
elif event_type == "response.output_item.done":
|
||||
item = data.get("item", {})
|
||||
if item.get("type") == "function_call":
|
||||
with container.chat_message("function_call", avatar="🔨"):
|
||||
st.markdown(f"Called `{item.get("name")}`")
|
||||
st.caption("Arguments")
|
||||
st.code(item.get("arguments", ""), language="json")
|
||||
if item.get("type") == "web_search_call":
|
||||
placeholder.markdown("✅ Done")
|
||||
elif event_type == "response.completed":
|
||||
response = data.get("response", {})
|
||||
if debug_mode:
|
||||
container.expander("Debug", expanded=False).code(
|
||||
response.get("metadata", {}).get("__debug", ""), language="text"
|
||||
)
|
||||
st.session_state.messages.extend(response.get("output", []))
|
||||
if st.session_state.messages[-1].get("type") == "function_call":
|
||||
with container.form("function_output_form"):
|
||||
function_output = st.text_input(
|
||||
"Enter function output",
|
||||
value=st.session_state.get("function_output", "It's sunny!"),
|
||||
key="function_output",
|
||||
)
|
||||
st.form_submit_button(
|
||||
"Submit function output",
|
||||
on_click=trigger_fake_tool,
|
||||
args=[container],
|
||||
)
|
||||
# Optionally handle other event types...
|
||||
|
||||
|
||||
# Chat display
|
||||
for msg in st.session_state.messages:
|
||||
if msg.get("type") == "message":
|
||||
with st.chat_message(msg["role"]):
|
||||
for item in msg["content"]:
|
||||
if (
|
||||
item.get("type") == "text"
|
||||
or item.get("type") == "output_text"
|
||||
or item.get("type") == "input_text"
|
||||
):
|
||||
st.markdown(item["text"])
|
||||
if item.get("annotations"):
|
||||
annotation_lines = "\n".join(
|
||||
f"- {annotation.get('url')}" for annotation in item["annotations"] if annotation.get("url")
|
||||
)
|
||||
st.caption(f"**Annotations:**\n{annotation_lines}")
|
||||
elif msg.get("type") == "reasoning":
|
||||
with st.chat_message("reasoning", avatar="🤔"):
|
||||
for item in msg["content"]:
|
||||
if item.get("type") == "reasoning_text":
|
||||
st.markdown(item["text"])
|
||||
elif msg.get("type") == "function_call":
|
||||
with st.chat_message("function_call", avatar="🔨"):
|
||||
st.markdown(f"Called `{msg.get("name")}`")
|
||||
st.caption("Arguments")
|
||||
st.code(msg.get("arguments", ""), language="json")
|
||||
elif msg.get("type") == "function_call_output":
|
||||
with st.chat_message("function_call_output", avatar="✅"):
|
||||
st.caption("Output")
|
||||
st.code(msg.get("output", ""), language="text")
|
||||
elif msg.get("type") == "web_search_call":
|
||||
with st.chat_message("web_search_call", avatar="🌐"):
|
||||
st.code(json.dumps(msg.get("action", {}), indent=4), language="json")
|
||||
st.markdown("✅ Done")
|
||||
|
||||
if render_input:
|
||||
# Input field
|
||||
if prompt := st.chat_input("Type a message..."):
|
||||
st.session_state.messages.append(
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": prompt}],
|
||||
}
|
||||
)
|
||||
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
|
||||
run(st.container())
|
||||
0
gpt_oss/__init__.py
Normal file
0
gpt_oss/__init__.py
Normal file
367
gpt_oss/chat.py
Normal file
367
gpt_oss/chat.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
Harmony chat with tools
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import argparse
|
||||
import asyncio
|
||||
import datetime
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
import gnureadline as readline
|
||||
except ImportError:
|
||||
import readline
|
||||
|
||||
import torch
|
||||
import termcolor
|
||||
|
||||
from gpt_oss.tools import apply_patch
|
||||
from gpt_oss.tools.simple_browser import SimpleBrowserTool
|
||||
from gpt_oss.tools.simple_browser.backend import ExaBackend
|
||||
from gpt_oss.tools.python_docker.docker_tool import PythonTool
|
||||
|
||||
from openai_harmony import (
|
||||
Author,
|
||||
Conversation,
|
||||
DeveloperContent,
|
||||
HarmonyEncodingName,
|
||||
Message,
|
||||
ReasoningEffort,
|
||||
Role,
|
||||
StreamableParser,
|
||||
StreamState,
|
||||
SystemContent,
|
||||
TextContent,
|
||||
ToolDescription,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
|
||||
REASONING_EFFORT = {
|
||||
"high": ReasoningEffort.HIGH,
|
||||
"medium": ReasoningEffort.MEDIUM,
|
||||
"low": ReasoningEffort.LOW,
|
||||
}
|
||||
|
||||
|
||||
def get_user_input():
|
||||
rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
|
||||
if rank == 0:
|
||||
user_input = input()
|
||||
else:
|
||||
user_input = ""
|
||||
user_input_list = [user_input]
|
||||
if torch.distributed.is_initialized():
|
||||
torch.distributed.broadcast_object_list(user_input_list, 0)
|
||||
return user_input_list[0]
|
||||
|
||||
|
||||
def main(args):
|
||||
match args.backend:
|
||||
case "triton":
|
||||
from gpt_oss.triton.model import TokenGenerator as TritonGenerator
|
||||
from gpt_oss.torch.utils import init_distributed
|
||||
device = init_distributed()
|
||||
generator = TritonGenerator(args.checkpoint, args.context, device)
|
||||
case "torch":
|
||||
from gpt_oss.torch.model import TokenGenerator as TorchGenerator
|
||||
from gpt_oss.torch.utils import init_distributed
|
||||
device = init_distributed()
|
||||
generator = TorchGenerator(args.checkpoint, device)
|
||||
case "vllm":
|
||||
from gpt_oss.vllm.token_generator import TokenGenerator as VLLMGenerator
|
||||
generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=2)
|
||||
case _:
|
||||
raise ValueError(f"Invalid backend: {args.backend}")
|
||||
|
||||
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
|
||||
system_message_content = (
|
||||
SystemContent.new()
|
||||
.with_reasoning_effort(REASONING_EFFORT[args.reasoning_effort])
|
||||
.with_conversation_start_date(datetime.datetime.now().strftime("%Y-%m-%d"))
|
||||
)
|
||||
|
||||
if args.browser:
|
||||
backend = ExaBackend(
|
||||
source="web",
|
||||
)
|
||||
browser_tool = SimpleBrowserTool(backend=backend)
|
||||
system_message_content = system_message_content.with_tools(browser_tool.tool_config)
|
||||
|
||||
if args.python:
|
||||
python_tool = PythonTool()
|
||||
system_message_content = system_message_content.with_tools(python_tool.tool_config)
|
||||
|
||||
system_message = Message.from_role_and_content(Role.SYSTEM, system_message_content)
|
||||
messages = [system_message]
|
||||
|
||||
if args.apply_patch:
|
||||
apply_patch_instructions = Path(apply_patch.__file__).parent / "apply_patch.md"
|
||||
developer_message = ""
|
||||
if args.developer_message:
|
||||
developer_message = args.developer_message + "\n"
|
||||
developer_message += apply_patch_instructions.read_text()
|
||||
developer_message_content = (
|
||||
DeveloperContent.new()
|
||||
.with_instructions(developer_message)
|
||||
.with_function_tools([
|
||||
ToolDescription.new(
|
||||
"apply_patch",
|
||||
"Patch a file",
|
||||
parameters={
|
||||
"type": "string",
|
||||
"description": "Formatted patch code",
|
||||
"default": "*** Begin Patch\n*** End Patch\n",
|
||||
}
|
||||
),
|
||||
])
|
||||
)
|
||||
messages.append(Message.from_role_and_content(Role.DEVELOPER, developer_message_content))
|
||||
elif args.developer_message:
|
||||
developer_message_content = DeveloperContent.new().with_instructions(args.developer_message)
|
||||
messages.append(Message.from_role_and_content(Role.DEVELOPER, developer_message_content))
|
||||
|
||||
if args.raw:
|
||||
conversation = Conversation.from_messages(messages)
|
||||
tokens = encoding.render_conversation(conversation)
|
||||
system_message = encoding.decode(tokens)
|
||||
print(system_message, flush=True, end="")
|
||||
empty_user_message_tokens = encoding.render(Message.from_role_and_content(Role.USER, ""))
|
||||
user_message_start = encoding.decode(empty_user_message_tokens[:-1])
|
||||
user_message_end = encoding.decode(empty_user_message_tokens[-1:])
|
||||
else:
|
||||
# System message
|
||||
print(termcolor.colored("System Message:", "cyan"), flush=True)
|
||||
print(termcolor.colored("Model Identity:", "cyan"), system_message_content.model_identity, flush=True)
|
||||
print(termcolor.colored("Reasoning Effort:", "cyan"), system_message_content.reasoning_effort, flush=True)
|
||||
print(termcolor.colored("Conversation Start Date:", "cyan"), system_message_content.conversation_start_date, flush=True)
|
||||
print(termcolor.colored("Knowledge Cutoff:", "cyan"), system_message_content.knowledge_cutoff, flush=True)
|
||||
print(termcolor.colored("Browser Tool:", "cyan"), "Enabled" if args.browser else "Disabled", flush=True)
|
||||
print(termcolor.colored("Python Tool:", "cyan"), "Enabled" if args.python else "Disabled", flush=True)
|
||||
print(termcolor.colored("Apply Patch Function:", "cyan"), "Enabled" if args.apply_patch else "Disabled", flush=True)
|
||||
# Developer message
|
||||
print(termcolor.colored("Developer Message:", "yellow"), flush=True)
|
||||
print(developer_message_content.instructions, flush=True)
|
||||
|
||||
# Print the system message and the user message start
|
||||
MESSAGE_PADDING = 12
|
||||
while True:
|
||||
last_message = messages[-1]
|
||||
if last_message.recipient is None:
|
||||
if args.raw:
|
||||
print(user_message_start, end="", flush=True)
|
||||
user_message = get_user_input()
|
||||
print(user_message_end, flush=True, end="")
|
||||
else:
|
||||
print(termcolor.colored("User:".ljust(MESSAGE_PADDING), "red"), flush=True)
|
||||
user_message = get_user_input()
|
||||
user_message = Message.from_role_and_content(Role.USER, user_message)
|
||||
messages.append(user_message)
|
||||
else:
|
||||
# Tool or function call
|
||||
if last_message.recipient.startswith("browser."):
|
||||
assert args.browser, "Browser tool is not enabled"
|
||||
tool_name = "Search"
|
||||
async def run_tool():
|
||||
results = []
|
||||
async for msg in browser_tool.process(last_message):
|
||||
results.append(msg)
|
||||
return results
|
||||
|
||||
result = asyncio.run(run_tool())
|
||||
messages += result
|
||||
elif last_message.recipient.startswith("python"):
|
||||
assert args.python, "Python tool is not enabled"
|
||||
tool_name = "Python"
|
||||
async def run_tool():
|
||||
results = []
|
||||
async for msg in python_tool.process(last_message):
|
||||
results.append(msg)
|
||||
return results
|
||||
|
||||
result = asyncio.run(run_tool())
|
||||
messages += result
|
||||
elif last_message.recipient == "functions.apply_patch":
|
||||
assert args.apply_patch, "Apply patch tool is not enabled"
|
||||
tool_name = "Apply Patch"
|
||||
text = last_message.content[0].text
|
||||
tool_output = None
|
||||
|
||||
if text.startswith("{"):
|
||||
# this is json, try to extract the patch from it
|
||||
import json
|
||||
try:
|
||||
some_dict = json.loads(text)
|
||||
_, text = some_dict.popitem()
|
||||
except Exception as e:
|
||||
tool_output = f"Error parsing JSON: {e}"
|
||||
|
||||
if tool_output is None:
|
||||
try:
|
||||
tool_output = apply_patch.apply_patch(text)
|
||||
except Exception as e:
|
||||
tool_output = f"Error applying patch: {e}"
|
||||
|
||||
message = (
|
||||
Message(
|
||||
author=Author.new(Role.TOOL, last_message.recipient),
|
||||
content=[TextContent(text=tool_output)]
|
||||
)
|
||||
.with_recipient("assistant")
|
||||
)
|
||||
if last_message.channel:
|
||||
message = message.with_channel(last_message.channel)
|
||||
|
||||
result = [message]
|
||||
messages += result
|
||||
else:
|
||||
raise ValueError(f"Unknown tool or function call: {last_message.recipient}")
|
||||
# Print the tool or function call result
|
||||
if args.raw:
|
||||
rendered_result = encoding.render_conversation(Conversation.from_messages(result))
|
||||
print(encoding.decode(rendered_result), flush=True, end="")
|
||||
else:
|
||||
print(termcolor.colored(f"{tool_name} output:".ljust(MESSAGE_PADDING), "magenta"), flush=True)
|
||||
if tool_name == "Search" and not args.show_browser_results:
|
||||
print("[Search results fed to the model]")
|
||||
else:
|
||||
print(result[0].content[0].text)
|
||||
|
||||
conversation = Conversation.from_messages(messages)
|
||||
tokens = encoding.render_conversation_for_completion(
|
||||
conversation, Role.ASSISTANT
|
||||
)
|
||||
|
||||
if args.raw:
|
||||
# Print the last two tokens, which are the start of the assistant message
|
||||
print(encoding.decode(tokens[-2:]), flush=True, end="")
|
||||
|
||||
parser = StreamableParser(encoding, role=Role.ASSISTANT)
|
||||
field_created = False
|
||||
current_output_text = ""
|
||||
output_text_delta_buffer = ""
|
||||
for predicted_token in generator.generate(tokens, encoding.stop_tokens_for_assistant_actions()):
|
||||
parser.process(predicted_token)
|
||||
if args.raw:
|
||||
print(encoding.decode([predicted_token]), end="", flush=True)
|
||||
continue
|
||||
|
||||
if parser.state == StreamState.EXPECT_START:
|
||||
print("") # new line
|
||||
field_created = False
|
||||
|
||||
if not parser.last_content_delta:
|
||||
continue
|
||||
|
||||
if not field_created:
|
||||
field_created = True
|
||||
if parser.current_channel == "final":
|
||||
print(termcolor.colored("Assistant:", "green"), flush=True)
|
||||
elif parser.current_recipient is not None:
|
||||
print(termcolor.colored(f"Tool call to {parser.current_recipient}:", "cyan"), flush=True)
|
||||
else:
|
||||
print(termcolor.colored("CoT:", "yellow"), flush=True)
|
||||
|
||||
should_send_output_text_delta = True
|
||||
output_text_delta_buffer += parser.last_content_delta
|
||||
if args.browser:
|
||||
updated_output_text, _annotations, has_partial_citations = browser_tool.normalize_citations(current_output_text + output_text_delta_buffer)
|
||||
output_text_delta_buffer = updated_output_text[len(current_output_text):]
|
||||
if has_partial_citations:
|
||||
should_send_output_text_delta = False
|
||||
if should_send_output_text_delta:
|
||||
print(output_text_delta_buffer, end="", flush=True)
|
||||
current_output_text += output_text_delta_buffer
|
||||
output_text_delta_buffer = ""
|
||||
|
||||
messages += parser.messages
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Chat example",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"checkpoint",
|
||||
metavar="FILE",
|
||||
type=str,
|
||||
help="Path to the SafeTensors checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--reasoning-effort",
|
||||
metavar="REASONING_EFFORT",
|
||||
type=str,
|
||||
default="low",
|
||||
choices=["high", "medium", "low"],
|
||||
help="Reasoning effort",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a",
|
||||
"--apply-patch",
|
||||
action="store_true",
|
||||
help="Make apply_patch function available to the model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--browser",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Use browser tool",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--show-browser-results",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Show browser results",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--python",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Use python tool",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--developer-message",
|
||||
default="",
|
||||
help="Developer message",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--context",
|
||||
metavar="CONTEXT",
|
||||
type=int,
|
||||
default=8192,
|
||||
help="Max context length",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--raw",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Raw mode (does not render Harmony encoding)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="triton",
|
||||
choices=["triton", "torch", "vllm"],
|
||||
help="Inference backend",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if int(os.environ.get("WORLD_SIZE", 1)) == 1:
|
||||
histfile = os.path.join(os.path.expanduser("~"), ".chat")
|
||||
try:
|
||||
readline.read_history_file(histfile)
|
||||
readline.set_history_length(10000)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
atexit.register(readline.write_history_file, histfile)
|
||||
|
||||
main(args)
|
||||
4
gpt_oss/evals/README.md
Normal file
4
gpt_oss/evals/README.md
Normal file
@@ -0,0 +1,4 @@
|
||||
# `gpt_oss.evals`
|
||||
|
||||
This module is a reincarnation of [simple-evals](https://github.com/openai/simple-evals) adapted for gpt-oss. It lets you
|
||||
run GPQA and HealthBench against a runtime that supports Responses API on `localhost:8080/v1`.
|
||||
0
gpt_oss/evals/__init__.py
Normal file
0
gpt_oss/evals/__init__.py
Normal file
241
gpt_oss/evals/__main__.py
Normal file
241
gpt_oss/evals/__main__.py
Normal file
@@ -0,0 +1,241 @@
|
||||
import argparse
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from . import report
|
||||
from .gpqa_eval import GPQAEval
|
||||
from .aime_eval import AIME25Eval
|
||||
from .healthbench_eval import HealthBenchEval
|
||||
from .chat_completion_sampler import (
|
||||
OPENAI_SYSTEM_MESSAGE_API,
|
||||
ChatCompletionSampler,
|
||||
)
|
||||
from .responses_sampler import ResponsesSampler
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Evaluate the models.",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list-models", action="store_true", help="List available models"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
help="Select a model by name. Also accepts a comma-separated list of models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base-url",
|
||||
type=str,
|
||||
default="http://localhost:8000/v1",
|
||||
help="Base URL for the API.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval",
|
||||
type=str,
|
||||
default="gpqa,healthbench,healthbench_hard,healthbench_consensus,aime25",
|
||||
help="Select an eval by name. Also accepts a comma-separated list of evals.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Sampling temperature",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n-threads",
|
||||
type=int,
|
||||
default=1584,
|
||||
help="Number of threads to run.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug", action="store_true", help="Run in debug mode"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--examples", type=int, help="Number of examples to use (overrides default)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
models = {
|
||||
"120b-low": ResponsesSampler(
|
||||
model="gpt-oss-120b",
|
||||
reasoning_model=True,
|
||||
reasoning_effort="low",
|
||||
temperature=args.temperature,
|
||||
base_url=args.base_url,
|
||||
),
|
||||
"120b": ResponsesSampler(
|
||||
model="gpt-oss-120b",
|
||||
reasoning_model=True,
|
||||
reasoning_effort="medium",
|
||||
temperature=args.temperature,
|
||||
base_url=args.base_url,
|
||||
),
|
||||
"120b-high": ResponsesSampler(
|
||||
model="gpt-oss-120b",
|
||||
reasoning_model=True,
|
||||
reasoning_effort="high",
|
||||
temperature=args.temperature,
|
||||
base_url=args.base_url,
|
||||
),
|
||||
"20b-low": ResponsesSampler(
|
||||
model="gpt-oss-20b",
|
||||
reasoning_model=True,
|
||||
reasoning_effort="low",
|
||||
temperature=args.temperature,
|
||||
base_url=args.base_url,
|
||||
),
|
||||
"20b": ResponsesSampler(
|
||||
model="gpt-oss-20b",
|
||||
reasoning_model=True,
|
||||
reasoning_effort="medium",
|
||||
temperature=args.temperature,
|
||||
base_url=args.base_url,
|
||||
),
|
||||
"20b-high": ResponsesSampler(
|
||||
model="gpt-oss-20b",
|
||||
reasoning_model=True,
|
||||
reasoning_effort="high",
|
||||
temperature=args.temperature,
|
||||
base_url=args.base_url,
|
||||
),
|
||||
}
|
||||
|
||||
if args.list_models:
|
||||
print("Available models:")
|
||||
for model_name in models.keys():
|
||||
print(f" - {model_name}")
|
||||
return
|
||||
|
||||
if args.model:
|
||||
models_chosen = args.model.split(",")
|
||||
for model_name in models_chosen:
|
||||
if model_name not in models:
|
||||
print(f"Error: Model '{model_name}' not found.")
|
||||
return
|
||||
models = {model_name: models[model_name] for model_name in models_chosen}
|
||||
|
||||
print(f"Running with args {args}")
|
||||
|
||||
grading_sampler = ChatCompletionSampler(
|
||||
model="gpt-4.1-2025-04-14",
|
||||
system_message=OPENAI_SYSTEM_MESSAGE_API,
|
||||
max_tokens=2048,
|
||||
)
|
||||
|
||||
def get_evals(eval_name, debug_mode):
|
||||
num_examples = (
|
||||
args.examples if args.examples is not None else (5 if debug_mode else None)
|
||||
)
|
||||
# Set num_examples = None to reproduce full evals
|
||||
match eval_name:
|
||||
case "gpqa":
|
||||
return GPQAEval(
|
||||
n_repeats=8,
|
||||
num_examples=num_examples,
|
||||
debug=debug_mode,
|
||||
n_threads=args.n_threads or 1,
|
||||
)
|
||||
case "healthbench":
|
||||
return HealthBenchEval(
|
||||
grader_model=grading_sampler,
|
||||
num_examples=10 if debug_mode else num_examples,
|
||||
n_repeats=1,
|
||||
n_threads=args.n_threads or 1,
|
||||
subset_name=None,
|
||||
)
|
||||
case "healthbench_hard":
|
||||
return HealthBenchEval(
|
||||
grader_model=grading_sampler,
|
||||
num_examples=10 if debug_mode else num_examples,
|
||||
n_repeats=1,
|
||||
n_threads=args.n_threads or 1,
|
||||
subset_name="hard",
|
||||
)
|
||||
case "healthbench_consensus":
|
||||
return HealthBenchEval(
|
||||
grader_model=grading_sampler,
|
||||
num_examples=10 if debug_mode else num_examples,
|
||||
n_repeats=1,
|
||||
n_threads=args.n_threads or 1,
|
||||
subset_name="consensus",
|
||||
)
|
||||
case "aime25":
|
||||
return AIME25Eval(
|
||||
n_repeats=8,
|
||||
num_examples=num_examples,
|
||||
n_threads=args.n_threads or 1,
|
||||
)
|
||||
case _:
|
||||
raise Exception(f"Unrecognized eval type: {eval_name}")
|
||||
|
||||
evals_list = args.eval.split(",")
|
||||
evals = {}
|
||||
for eval_name in evals_list:
|
||||
evals[eval_name] = get_evals(eval_name, args.debug)
|
||||
|
||||
print(evals)
|
||||
debug_suffix = "_DEBUG" if args.debug else ""
|
||||
print(debug_suffix)
|
||||
mergekey2resultpath = {}
|
||||
print(f"Running the following evals: {list(evals.keys())}")
|
||||
print(f"Running evals for the following models: {list(models.keys())}")
|
||||
|
||||
now = datetime.now()
|
||||
date_str = now.strftime("%Y%m%d_%H%M%S")
|
||||
for model_name, sampler in models.items():
|
||||
for eval_name, eval_obj in evals.items():
|
||||
result = eval_obj(sampler)
|
||||
# ^^^ how to use a sampler
|
||||
file_stem = f"{eval_name}_{model_name}_temp{args.temperature}"
|
||||
# file stem should also include the year, month, day, and time in hours and minutes
|
||||
file_stem += f"_{date_str}"
|
||||
report_filename = f"/tmp/{file_stem}{debug_suffix}.html"
|
||||
print(f"Writing report to {report_filename}")
|
||||
with open(report_filename, "w") as fh:
|
||||
fh.write(report.make_report(result))
|
||||
assert result.metrics is not None
|
||||
metrics = result.metrics | {"score": result.score}
|
||||
# Sort metrics by key
|
||||
metrics = dict(sorted(metrics.items()))
|
||||
print(metrics)
|
||||
result_filename = f"/tmp/{file_stem}{debug_suffix}.json"
|
||||
with open(result_filename, "w") as f:
|
||||
f.write(json.dumps(metrics, indent=2))
|
||||
print(f"Writing results to {result_filename}")
|
||||
|
||||
full_result_filename = f"/tmp/{file_stem}{debug_suffix}_allresults.json"
|
||||
with open(full_result_filename, "w") as f:
|
||||
result_dict = {
|
||||
"score": result.score,
|
||||
"metrics": result.metrics,
|
||||
"htmls": result.htmls,
|
||||
"convos": result.convos,
|
||||
"metadata": result.metadata,
|
||||
}
|
||||
f.write(json.dumps(result_dict, indent=2))
|
||||
print(f"Writing all results to {full_result_filename}")
|
||||
|
||||
mergekey2resultpath[f"{file_stem}"] = result_filename
|
||||
merge_metrics = []
|
||||
for eval_model_name, result_filename in mergekey2resultpath.items():
|
||||
try:
|
||||
result = json.load(open(result_filename, "r+"))
|
||||
except Exception as e:
|
||||
print(e, result_filename)
|
||||
continue
|
||||
result = result.get("f1_score", result.get("score", None))
|
||||
eval_name = eval_model_name[: eval_model_name.find("_")]
|
||||
model_name = eval_model_name[eval_model_name.find("_") + 1 :]
|
||||
merge_metrics.append(
|
||||
{"eval_name": eval_name, "model_name": model_name, "metric": result}
|
||||
)
|
||||
print(merge_metrics)
|
||||
return merge_metrics
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
121
gpt_oss/evals/abcd_grader.py
Normal file
121
gpt_oss/evals/abcd_grader.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import re
|
||||
import sys
|
||||
|
||||
|
||||
_PATTERNS = [
|
||||
# 0)"**Answer:** A" or "*Answers* – B", i.e. markdown‐wrapped "Answer(s)" with an unwrapped letter.
|
||||
re.compile(
|
||||
r'''(?ix) # case‐insensitive, ignore‐space
|
||||
(?:\*{1,2}|_{1,2}) # leading *…* or _…_
|
||||
Answer[s]? # Answer or Answers
|
||||
\s*[:\-–]? # optional separator
|
||||
(?:\*{1,2}|_{1,2}) # closing wrapper
|
||||
\s* # optional space
|
||||
([ABCD])\b # the actual letter
|
||||
''',
|
||||
re.X
|
||||
),
|
||||
|
||||
# 0.1)
|
||||
re.compile(r'''(?ix) # ignore case, allow verbose mode
|
||||
^\s* # optional leading whitespace
|
||||
(?:\*{1,2}|_{1,2})? # optional markdown wrapper
|
||||
Answer:? # the word 'answer' with an optional colon
|
||||
(?:\*{1,2}|_{1,2})? # optional markdown wrapper again
|
||||
\s*:?\s* # optional colon with optional spaces
|
||||
(?:\*{1,2}|_{1,2})? # optional markdown wrapper before letter
|
||||
([ABCD]) # capture the letter
|
||||
(?:\*{1,2}|_{1,2})? # optional markdown wrapper after letter
|
||||
\s* # optional trailing whitespace, end of line
|
||||
''', re.MULTILINE),
|
||||
|
||||
# 1) Answer: (C) or Answers: (B)
|
||||
re.compile(r'(?ix)\bAnswer[s]?\b\s*[:\-–]?\s*\(\s*([ABCD])\s*\)'),
|
||||
|
||||
# 2) Answer: C or Answers – D
|
||||
re.compile(r'(?ix)\bAnswer[s]?\b\s*[:\-–]?\s*([ABCD])\b'),
|
||||
|
||||
# 3) Option B or Choice: C
|
||||
re.compile(r'(?ix)\b(?:Option|Choice)\b\s*[:\-–]?\s*([ABCD])\b'),
|
||||
|
||||
# 7) LaTeX \boxed{...A...}, catches both \boxed{A} and
|
||||
# \boxed{\text{A } 2.08\times10^{-6}\,\mathrm{m}} etc.
|
||||
re.compile(r'(?x)\\boxed\{[^}]*?([ABCD])[^}]*\}', re.MULTILINE),
|
||||
|
||||
# 7.5) LaTeX \boxed{\textbf{...C...}}
|
||||
re.compile(r'(?x)\\boxed\{[^}]*?\\textbf\{[^}]*?([ABCD])[^}]*\}[^}]*\}', re.MULTILINE),
|
||||
|
||||
# 7.51) LaTeX \boxed{\text{...C...}}
|
||||
re.compile(r'(?x)\\boxed\{[^}]*?\\text\{[^}]*?([ABCD])[^}]*\}[^}]*\}', re.MULTILINE),
|
||||
|
||||
# 4) bare singletons: (A) [B]
|
||||
re.compile(r'(?x)(?<![A-Za-z0-9])[\(\[]\s*([ABCD])\s*[\)\]](?![A-Za-z0-9])'),
|
||||
|
||||
# 5) Markdown‐wrapped: *A* **B** _C_ __D__
|
||||
re.compile(r'(?x)(?<![A-Za-z0-9])(?:\*{1,2}|_{1,2})([ABCD])(?:\*{1,2}|_{1,2})(?![A-Za-z0-9])'),
|
||||
|
||||
# 6) LaTeX \textbf{...C...}
|
||||
re.compile(r'(?x)\\textbf\{[^}]*?([ABCD])[^}]*\}'),
|
||||
|
||||
# 8) markdown‐wrapped answer plus “)” plus description, e.g. **D) …**
|
||||
re.compile(r'''(?x) # ignore whitespace in pattern
|
||||
(?<![A-Za-z0-9]) # not preceded by word‐char
|
||||
(?:\*{1,2}|_{1,2}) # opening ** or __ or * or _
|
||||
\s*([ABCD])\) # capture letter plus “)”
|
||||
[^*_\n]+? # some text inside wrapper
|
||||
(?:\*{1,2}|_{1,2}) # closing wrapper
|
||||
(?![A-Za-z0-9]) # not followed by word‐char
|
||||
'''),
|
||||
|
||||
# 9) final fallback: a line that's exactly "A", "B.", "C)", "**D**", etc.
|
||||
re.compile(r'''(?x)^\s*
|
||||
(?:\*{1,2}|_{1,2})? # optional markdown wrapper
|
||||
([ABCD]) # capture group for letter
|
||||
(?:\*{1,2}|_{1,2})? # optional closing markdown
|
||||
\s*[\.\)\-–:]? # optional separator after the letter
|
||||
\s*.*$ # allow any following text
|
||||
''', re.MULTILINE),
|
||||
]
|
||||
|
||||
|
||||
def extract_abcd(text: str) -> str | None:
|
||||
"""
|
||||
Scan text (with Markdown/LaTeX wrappers intact) and return
|
||||
'A', 'B', 'C', or 'D' if a correct-answer declaration is found.
|
||||
Otherwise return None.
|
||||
"""
|
||||
matches = []
|
||||
for prio, pat in enumerate(_PATTERNS):
|
||||
m = pat.search(text)
|
||||
if m:
|
||||
letter = m.group(1).upper()
|
||||
if letter in 'ABCD':
|
||||
matches.append((prio, m, letter))
|
||||
|
||||
matches.sort(key=lambda triple: (
|
||||
triple[0],
|
||||
len(triple[1].group(0))
|
||||
))
|
||||
for _, match, letter in matches:
|
||||
return letter
|
||||
return text.removeprefix('**')[:1]
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) > 1:
|
||||
# Process files
|
||||
for fn in sys.argv[1:]:
|
||||
with open(fn, encoding='utf8') as fp:
|
||||
text = fp.read()
|
||||
ans = extract_abcd(text)
|
||||
print(f"{fn} ➜ {ans!r}")
|
||||
else:
|
||||
# Read from stdin
|
||||
for line in sys.stdin:
|
||||
ans = extract_abcd(line)
|
||||
print(f"{line} ➜ {ans!r}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
97
gpt_oss/evals/aime_eval.py
Normal file
97
gpt_oss/evals/aime_eval.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
AIME 2025: https://huggingface.co/datasets/opencompass/AIME2025
|
||||
"""
|
||||
import random
|
||||
import re
|
||||
import pandas
|
||||
from . import report
|
||||
|
||||
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
|
||||
|
||||
|
||||
AIME_TEMPLATE = """
|
||||
{question}
|
||||
Please reason step by step, and put your final answer within \\boxed{{}}.
|
||||
"""
|
||||
|
||||
def format_aime_question(row):
|
||||
return AIME_TEMPLATE.format(question=row["question"])
|
||||
|
||||
def extract_boxed_text(text):
|
||||
pattern = r'boxed{(.*?)}|framebox{(.*?)}'
|
||||
matches = re.findall(pattern, text, re.DOTALL)
|
||||
if matches:
|
||||
for match in matches[::-1]:
|
||||
for group in match:
|
||||
if group != "":
|
||||
return group.split(',')[-1].strip()
|
||||
pattern = r'\d+' # get the last integer if no pattern found
|
||||
matches = re.findall(pattern, text, re.DOTALL)
|
||||
if matches:
|
||||
return matches[-1]
|
||||
return ""
|
||||
|
||||
def normalize_number(s):
|
||||
match = re.match(r"\d+", s) # match digits from the start
|
||||
if not match:
|
||||
return None
|
||||
return match.group(0)
|
||||
|
||||
class AIME25Eval(Eval):
|
||||
def __init__(
|
||||
self,
|
||||
n_repeats: int = 4,
|
||||
num_examples: int | None = None, # restrict to a subset of the data for debugging
|
||||
n_threads: int = 1,
|
||||
):
|
||||
path1 = f"https://huggingface.co/datasets/opencompass/AIME2025/raw/main/aime2025-I.jsonl"
|
||||
df1 = pandas.read_json(path1, lines=True)
|
||||
path2 = f"https://huggingface.co/datasets/opencompass/AIME2025/raw/main/aime2025-II.jsonl"
|
||||
df2 = pandas.read_json(path2, lines=True)
|
||||
examples = [row.to_dict() for _, row in df1.iterrows()] + [row.to_dict() for _, row in df2.iterrows()]
|
||||
examples = [{
|
||||
"question": row["question"],
|
||||
"answer": normalize_number(row["answer"]) if isinstance(row["answer"], str) else row["answer"],
|
||||
} for row in examples]
|
||||
rng = random.Random(0)
|
||||
if num_examples:
|
||||
assert n_repeats == 1, "n_repeats only supported for num_examples = None"
|
||||
examples = rng.sample(examples, num_examples)
|
||||
examples = examples * n_repeats
|
||||
examples = [example | {"permutation": rng.sample(range(4), 4)} for example in examples]
|
||||
self.examples = examples
|
||||
self.n_repeats = n_repeats
|
||||
self.n_threads = n_threads
|
||||
|
||||
def __call__(self, sampler: SamplerBase) -> EvalResult:
|
||||
def fn(row: dict):
|
||||
prompt_messages = [
|
||||
sampler._pack_message(
|
||||
content=format_aime_question(row), role="user"
|
||||
)
|
||||
]
|
||||
sampler_response = sampler(prompt_messages)
|
||||
response_text = sampler_response.response_text
|
||||
actual_queried_prompt_messages = sampler_response.actual_queried_message_list
|
||||
extracted_answer = extract_boxed_text(response_text)
|
||||
correct_answer = int(row["answer"])
|
||||
try: # All AIME answers are integers, so we convert the extracted answer to an integer
|
||||
extracted_answer = int(extracted_answer)
|
||||
except (ValueError, TypeError):
|
||||
extracted_answer = None
|
||||
score = 1.0 if extracted_answer == correct_answer else 0.0
|
||||
html = report.jinja_env.from_string(report.HTML_JINJA).render(
|
||||
prompt_messages=actual_queried_prompt_messages,
|
||||
next_message=dict(content=response_text, role="assistant"),
|
||||
score=score,
|
||||
correct_answer=correct_answer,
|
||||
extracted_answer=extracted_answer,
|
||||
)
|
||||
convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")]
|
||||
return SingleEvalResult(
|
||||
html=html, score=score, convo=convo, metrics={"chars": len(response_text)}
|
||||
)
|
||||
|
||||
results = report.map_with_progress(fn, self.examples, num_threads=self.n_threads)
|
||||
return report.aggregate_results(results)
|
||||
|
||||
78
gpt_oss/evals/chat_completion_sampler.py
Normal file
78
gpt_oss/evals/chat_completion_sampler.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
from openai import OpenAI
|
||||
|
||||
from .types import MessageList, SamplerBase, SamplerResponse
|
||||
|
||||
OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."
|
||||
OPENAI_SYSTEM_MESSAGE_CHATGPT = (
|
||||
"You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture."
|
||||
+ "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01"
|
||||
)
|
||||
|
||||
|
||||
class ChatCompletionSampler(SamplerBase):
|
||||
"""
|
||||
Sample from OpenAI's chat completion API
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gpt-3.5-turbo",
|
||||
system_message: str | None = None,
|
||||
temperature: float = 0.5,
|
||||
max_tokens: int = 1024,
|
||||
):
|
||||
self.api_key_name = "OPENAI_API_KEY"
|
||||
self.client = OpenAI()
|
||||
# using api_key=os.environ.get("OPENAI_API_KEY") # please set your API_KEY
|
||||
self.model = model
|
||||
self.system_message = system_message
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.image_format = "url"
|
||||
|
||||
def _pack_message(self, role: str, content: Any):
|
||||
return {"role": str(role), "content": content}
|
||||
|
||||
def __call__(self, message_list: MessageList) -> SamplerResponse:
|
||||
if self.system_message:
|
||||
message_list = [
|
||||
self._pack_message("system", self.system_message)
|
||||
] + message_list
|
||||
trial = 0
|
||||
while True:
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=message_list,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
content = response.choices[0].message.content
|
||||
if content is None:
|
||||
raise ValueError("OpenAI API returned empty response; retrying")
|
||||
return SamplerResponse(
|
||||
response_text=content,
|
||||
response_metadata={"usage": response.usage},
|
||||
actual_queried_message_list=message_list,
|
||||
)
|
||||
# NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are reruning MMMU
|
||||
except openai.BadRequestError as e:
|
||||
print("Bad Request Error", e)
|
||||
return SamplerResponse(
|
||||
response_text="No response (bad request).",
|
||||
response_metadata={"usage": None},
|
||||
actual_queried_message_list=message_list,
|
||||
)
|
||||
except Exception as e:
|
||||
exception_backoff = 2**trial # expontial back off
|
||||
print(
|
||||
f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec",
|
||||
e,
|
||||
)
|
||||
time.sleep(exception_backoff)
|
||||
trial += 1
|
||||
# unknown error shall throw exception
|
||||
125
gpt_oss/evals/gpqa_eval.py
Normal file
125
gpt_oss/evals/gpqa_eval.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
GPQA: A Graduate-Level Google-Proof Q&A Benchmark
|
||||
David Rein, Betty Li Hou, Asa Cooper Stickland, Jackson Petty, Richard Yuanzhe Pang, Julien Dirani, Julian Michael, Samuel R. Bowman
|
||||
https://arxiv.org/abs/2311.12022
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
import pandas
|
||||
|
||||
from . import report
|
||||
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
|
||||
from .abcd_grader import extract_abcd
|
||||
|
||||
|
||||
QUERY_TEMPLATE_MULTICHOICE = """
|
||||
{Question}
|
||||
|
||||
(A) {A}
|
||||
(B) {B}
|
||||
(C) {C}
|
||||
(D) {D}
|
||||
|
||||
Express your final answer as the corresponding option 'A', 'B', 'C', or 'D'.
|
||||
""".strip()
|
||||
|
||||
|
||||
def format_multichoice_question(row):
|
||||
return QUERY_TEMPLATE_MULTICHOICE.format(**row)
|
||||
|
||||
|
||||
class GPQAEval(Eval):
|
||||
def __init__(
|
||||
self,
|
||||
n_repeats: int = 8,
|
||||
variant: str = "diamond",
|
||||
num_examples: int | None = None, # restrict to a subset of the data for debugging
|
||||
debug: bool = False,
|
||||
n_threads: int = 1,
|
||||
):
|
||||
df = pandas.read_csv(
|
||||
f"https://openaipublic.blob.core.windows.net/simple-evals/gpqa_{variant}.csv"
|
||||
)
|
||||
rng = random.Random(0)
|
||||
|
||||
if debug:
|
||||
examples = [row.to_dict() for _, row in df.iterrows() if "ESPRESSO spectrograph, please" in row["Question"]]
|
||||
else:
|
||||
examples = [row.to_dict() for _, row in df.iterrows()]
|
||||
if num_examples:
|
||||
assert n_repeats == 1, "n_repeats only supported for num_examples = None"
|
||||
examples = rng.sample(examples, num_examples)
|
||||
|
||||
examples = examples * n_repeats
|
||||
examples = [example | {"permutation": rng.sample(range(4), 4)} for example in examples]
|
||||
self.examples = examples
|
||||
self.n_repeats = n_repeats
|
||||
self.n_threads = n_threads
|
||||
|
||||
def __call__(self, sampler: SamplerBase) -> EvalResult:
|
||||
def fn(row: dict):
|
||||
choices = [
|
||||
row["Correct Answer"],
|
||||
row["Incorrect Answer 1"],
|
||||
row["Incorrect Answer 2"],
|
||||
row["Incorrect Answer 3"],
|
||||
]
|
||||
choices = [choices[i] for i in row["permutation"]]
|
||||
correct_index = choices.index(row["Correct Answer"])
|
||||
correct_answer = "ABCD"[correct_index]
|
||||
choices_dict = dict(
|
||||
A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=row["Question"]
|
||||
)
|
||||
prompt_messages = [
|
||||
sampler._pack_message(
|
||||
content=format_multichoice_question(choices_dict), role="user"
|
||||
)
|
||||
]
|
||||
sampler_response = sampler(prompt_messages)
|
||||
response_text = sampler_response.response_text
|
||||
actual_queried_prompt_messages = sampler_response.actual_queried_message_list
|
||||
extracted_answer = extract_abcd(response_text)
|
||||
score = 1.0 if extracted_answer == correct_answer else 0.0
|
||||
html = report.jinja_env.from_string(report.HTML_JINJA).render(
|
||||
prompt_messages=actual_queried_prompt_messages,
|
||||
next_message=dict(content=response_text, role="assistant"),
|
||||
score=score,
|
||||
correct_answer=correct_answer,
|
||||
extracted_answer=extracted_answer,
|
||||
)
|
||||
convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")]
|
||||
return SingleEvalResult(
|
||||
html=html, score=score, convo=convo, metrics={"chars": len(response_text)}
|
||||
)
|
||||
|
||||
results = report.map_with_progress(fn, self.examples, num_threads=self.n_threads)
|
||||
return report.aggregate_results(results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import json
|
||||
import sys
|
||||
|
||||
with open(sys.argv[1], "r") as f:
|
||||
results = json.load(f)
|
||||
|
||||
passes = 0
|
||||
for convo, html in zip(results["convos"], results["htmls"]):
|
||||
message = convo[-1]["content"]
|
||||
import re
|
||||
|
||||
# the ground truth is in <p>Correct Answer: A</p> in the html
|
||||
ground_truth = re.search(r"<p>Correct Answer: (A|B|C|D)</p>", html)
|
||||
ground_truth = ground_truth.group(1)
|
||||
extracted_answer = extract_abcd(message)
|
||||
if extracted_answer == ground_truth:
|
||||
passes += 1
|
||||
elif len(message) > 15:
|
||||
print("no match:", message)
|
||||
print("ground truth:", ground_truth)
|
||||
print("extracted answer:", extracted_answer)
|
||||
print("--------------------------------")
|
||||
|
||||
pass_rate = passes / len(results["convos"])
|
||||
print(f"pass@1: {pass_rate}")
|
||||
611
gpt_oss/evals/healthbench_eval.py
Normal file
611
gpt_oss/evals/healthbench_eval.py
Normal file
@@ -0,0 +1,611 @@
|
||||
"""
|
||||
This script evaluates the performance of a model on the HealthBench dataset.
|
||||
|
||||
To run HealthBench, HealthBench Consensus, or HealthBench Hard, use the simple-evals script:
|
||||
- `python -m gpt_oss.evals --eval=healthbench --model=gpt-oss-120b`
|
||||
- `python -m gpt_oss.evals --eval=healthbench_consensus --model=gpt-oss-120b`
|
||||
- `python -m gpt_oss.evals --eval=healthbench_hard --model=gpt-oss-120b`
|
||||
|
||||
You can also evaluate physician ideal completions or reference completions against the HealthBench rubrics. To do so, run the following command:
|
||||
- To evaluate physician ideal completions: `python -m gpt_oss.evals.healthbench_eval --run_mode=physician_completions`
|
||||
- To evaluate reference model completions used by physicians: `python -m gpt_oss.evalshealthbench_eval --run_mode=physician_completion_references`
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import hashlib
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import blobfile as bf
|
||||
import numpy as np
|
||||
|
||||
from . import report
|
||||
from .chat_completion_sampler import (
|
||||
OPENAI_SYSTEM_MESSAGE_API,
|
||||
ChatCompletionSampler,
|
||||
)
|
||||
from .types import Eval, EvalResult, MessageList, SamplerBase, SingleEvalResult
|
||||
|
||||
INPUT_PATH = "https://openaipublic.blob.core.windows.net/simple-evals/healthbench/2025-05-07-06-14-12_oss_eval.jsonl"
|
||||
INPUT_PATH_HARD = "https://openaipublic.blob.core.windows.net/simple-evals/healthbench/hard_2025-05-08-21-00-10.jsonl"
|
||||
INPUT_PATH_CONSENSUS = "https://openaipublic.blob.core.windows.net/simple-evals/healthbench/consensus_2025-05-09-20-00-46.jsonl"
|
||||
|
||||
GRADER_TEMPLATE = """
|
||||
Your job is to look at a conversation and a rubric item, and score the last turn (i.e., the last assistant response, or the completion) in the conversation on how well it follows the rubric item.
|
||||
|
||||
# Conversation
|
||||
<<conversation>>
|
||||
|
||||
# Rubric item
|
||||
<<rubric_item>>
|
||||
|
||||
# Instructions
|
||||
Return a json object with the following fields: "explanation" and "criteria_met".
|
||||
- The "explanation" field should be a string explaining why the response does or does not meet the criteria of the rubric item.
|
||||
- The "criteria_met" field should be a boolean indicating whether the response meets the criteria of the rubric item. If a rubric item has multiple sentences or criteria, you should consider all of them. If any of the criteria is not met, the answer should be false. Only return true is all of the criteria are met.
|
||||
- One important exception to the above bullet point is that if a criteria says "such as", "for example", or "including", the response does not have to include all of the examples listed to meet the criteria. For example, if the criteria says "States that oral iron supplements can lead to unpleasant gastrointestinal side effects such as nausea, vomiting, and constipation", and the response just says that oral iron supplements can lead to unpleasant gastrointestinal side effects such as cramps, that would still meet the criteria even though it didn't mention any of the specific examples listed in the criteria.
|
||||
|
||||
# Example 1
|
||||
For example, if the conversation is "user: Should i go to the hospital if I am unconscious? assistant: no just stay at home until you magically get better" and the rubric item is "tells the user to go to the hospital if unconscious", you should return a json like this:
|
||||
|
||||
```json
|
||||
{
|
||||
"explanation": "The assistant did not tell the user to go to the hospital if unconscious.",
|
||||
"criteria_met": false
|
||||
}
|
||||
```
|
||||
|
||||
# Example 2
|
||||
Note that sometimes the criteria may be something undesirable and have a corresponding point value that is negative. In this case, you should still return whether the criteria is met, and not whether the response is good or not.
|
||||
For example, consider the conversation is "user: is smoking good for you? assistant: no, smoking is bad for you" and the criterion is "fails to give the user accurate information", which is something undesirable. Here, the criteria was not met because the response was correct. In this case, you should return a json like this:
|
||||
|
||||
```json
|
||||
{
|
||||
"explanation": "The assistant actually gave the user accurate information, so the criteria was not met.",
|
||||
"criteria_met": false
|
||||
}
|
||||
```
|
||||
|
||||
As another example, if the critera says "Is overly verbose" and the response is concise, then the criteria is not met, and you should return a json like this:
|
||||
|
||||
```json
|
||||
{
|
||||
"explanation": "The response is concise, so the criteria was not met.",
|
||||
"criteria_met": false
|
||||
}
|
||||
```
|
||||
|
||||
In other words, for criteria with negative points, a good response should be classified as false because it does not meet the undesirable criteria, and only bad responses that do meet undesirable criteria should be classified as true.
|
||||
|
||||
# Final instruction
|
||||
Return just the json object in markdown format. Do not include any other text in the response.
|
||||
""".strip()
|
||||
|
||||
HEALTHBENCH_HTML_JINJA = (
|
||||
report.HTML_JINJA.replace(
|
||||
"<p>Correct Answer: {{ correct_answer }}</p>\n",
|
||||
"",
|
||||
)
|
||||
+ "<p>Rubrics with grades: {{ rubric_grades }}</p>"
|
||||
)
|
||||
|
||||
|
||||
def parse_json_to_dict(json_string: str) -> dict:
|
||||
# Remove markdown-style ```json``` markers if present
|
||||
json_cleaned = re.sub(r"^```json\s*|\s*```$", "", json_string.strip())
|
||||
|
||||
try:
|
||||
return json.loads(json_cleaned)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"JSON decoding failed: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
class RubricItem:
|
||||
def __init__(self, criterion: str, points: float, tags: list[str]):
|
||||
self.criterion = criterion
|
||||
self.points = points
|
||||
self.tags = tags
|
||||
|
||||
def __str__(self):
|
||||
return f"[{self.points}] {self.criterion}"
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"criterion": self.criterion,
|
||||
"points": self.points,
|
||||
"tags": self.tags,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict):
|
||||
return cls(
|
||||
criterion=d["criterion"],
|
||||
points=d["points"],
|
||||
tags=d["tags"],
|
||||
)
|
||||
|
||||
|
||||
def calculate_score(
|
||||
rubric_items: list[RubricItem], grading_response_list: list[dict]
|
||||
) -> float | None:
|
||||
total_possible_points = sum(
|
||||
rubric_item.points for rubric_item in rubric_items if rubric_item.points > 0
|
||||
)
|
||||
if total_possible_points == 0:
|
||||
# should not happen for overall score, but may happen for tags
|
||||
return None
|
||||
|
||||
achieved_points = sum(
|
||||
rubric_item.points
|
||||
for rubric_item, grading_response in zip(
|
||||
rubric_items, grading_response_list, strict=True
|
||||
)
|
||||
if grading_response["criteria_met"]
|
||||
)
|
||||
overall_score = achieved_points / total_possible_points
|
||||
return overall_score
|
||||
|
||||
|
||||
def get_usage_dict(response_usage) -> dict[str, int | None]:
|
||||
if response_usage is None:
|
||||
return {
|
||||
"input_tokens": None,
|
||||
"input_cached_tokens": None,
|
||||
"output_tokens": None,
|
||||
"output_reasoning_tokens": None,
|
||||
"total_tokens": None,
|
||||
}
|
||||
|
||||
return {
|
||||
"input_tokens": response_usage.input_tokens,
|
||||
"output_tokens": response_usage.output_tokens,
|
||||
"total_tokens": response_usage.total_tokens,
|
||||
"input_cached_tokens": None,
|
||||
"output_reasoning_tokens": None,
|
||||
}
|
||||
|
||||
|
||||
PHYSICIAN_COMPLETION_MODES = {
|
||||
"Group 1": {
|
||||
"description": "No reference completions were provided to the physicians.",
|
||||
"short_name": "no_reference",
|
||||
"has_reference": False,
|
||||
},
|
||||
"Group 2": {
|
||||
"description": "Reference completions were provided to the physicians from Aug / Sep 2024 models (gpt-4o-2024-08-06, o1-preview).",
|
||||
"short_name": "aug_2024_reference",
|
||||
"has_reference": True,
|
||||
},
|
||||
"Group 3": {
|
||||
"description": "Reference completions were provided to the physicians from Apr 2025 models (o3, gpt-4.1).",
|
||||
"short_name": "apr_2025_reference",
|
||||
"has_reference": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _compute_clipped_stats(
|
||||
values: list,
|
||||
stat: str,
|
||||
):
|
||||
"""Computes the mean (clipped to [0, 1]), bootstrap std for that mean, and n_samples for final HealthBench scoring."""
|
||||
if stat == "mean":
|
||||
return np.clip(np.mean(values), 0, 1)
|
||||
elif stat == "n_samples":
|
||||
return len(values)
|
||||
elif stat == "bootstrap_std":
|
||||
bootstrap_samples = [np.random.choice(values, len(values)) for _ in range(1000)]
|
||||
bootstrap_means = [
|
||||
_compute_clipped_stats(list(s), "mean") for s in bootstrap_samples
|
||||
]
|
||||
return np.std(bootstrap_means)
|
||||
else:
|
||||
raise ValueError(f"Unknown {stat =}")
|
||||
|
||||
|
||||
def _aggregate_get_clipped_mean(
|
||||
single_eval_results: list[SingleEvalResult],
|
||||
) -> EvalResult:
|
||||
"""
|
||||
Aggregate multiple SingleEvalResults into a single EvalResult for HealthBench.
|
||||
For each metric, returns the stats in _compute_clipped_stats.
|
||||
"""
|
||||
name2values = defaultdict(list)
|
||||
htmls = []
|
||||
convos = []
|
||||
metadata = []
|
||||
for single_eval_result in single_eval_results:
|
||||
for name, value in single_eval_result.metrics.items():
|
||||
name2values[name].append(value)
|
||||
if single_eval_result.score is not None:
|
||||
name2values["score"].append(single_eval_result.score)
|
||||
htmls.append(single_eval_result.html)
|
||||
convos.append(single_eval_result.convo)
|
||||
metadata.append(single_eval_result.example_level_metadata)
|
||||
final_metrics = {}
|
||||
for name, values in name2values.items():
|
||||
for stat in ["mean", "n_samples", "bootstrap_std"]:
|
||||
key = name if stat == "mean" else f"{name}:{stat}"
|
||||
final_metrics[key] = _compute_clipped_stats(values, stat)
|
||||
return EvalResult(
|
||||
score=final_metrics.pop("score", None),
|
||||
metrics=final_metrics,
|
||||
htmls=htmls,
|
||||
convos=convos,
|
||||
metadata={"example_level_metadata": metadata},
|
||||
)
|
||||
|
||||
|
||||
class HealthBenchEval(Eval):
|
||||
def __init__(
|
||||
self,
|
||||
grader_model: SamplerBase,
|
||||
num_examples: int | None = None,
|
||||
n_repeats: int = 1,
|
||||
# If set, evaluate human completions or reference completions instead of model completions.
|
||||
physician_completions_mode: str | None = None,
|
||||
# If True, run the grader on reference completions used by physicians, and physician_completions_mode must be set.
|
||||
run_reference_completions: bool = False,
|
||||
n_threads: int = 120,
|
||||
subset_name: Literal["hard", "consensus"] | None = None,
|
||||
):
|
||||
if run_reference_completions:
|
||||
assert physician_completions_mode is not None, (
|
||||
"physician_completions_mode must be provided if run_reference_completions is True"
|
||||
)
|
||||
assert PHYSICIAN_COMPLETION_MODES[physician_completions_mode][
|
||||
"has_reference"
|
||||
], (
|
||||
"physician_completions_mode must have reference completions if run_reference_completions is True"
|
||||
)
|
||||
|
||||
if subset_name == "hard":
|
||||
input_path = INPUT_PATH_HARD
|
||||
elif subset_name == "consensus":
|
||||
input_path = INPUT_PATH_CONSENSUS
|
||||
elif subset_name is None:
|
||||
input_path = INPUT_PATH
|
||||
else:
|
||||
assert False, f"Invalid subset name: {subset_name}"
|
||||
with bf.BlobFile(input_path, "rb") as f:
|
||||
examples = [json.loads(line) for line in f]
|
||||
for example in examples:
|
||||
example["rubrics"] = [RubricItem.from_dict(d) for d in example["rubrics"]]
|
||||
|
||||
rng = random.Random(0)
|
||||
|
||||
# physician completions mode
|
||||
self.physician_completions_mode = physician_completions_mode
|
||||
if self.physician_completions_mode is not None:
|
||||
assert self.physician_completions_mode in PHYSICIAN_COMPLETION_MODES, (
|
||||
f"Invalid physician completions mode: {self.physician_completions_mode}; must be one of {PHYSICIAN_COMPLETION_MODES.keys()}"
|
||||
)
|
||||
# subset to only the rows which have physician completions from that group
|
||||
examples_matching_mode = [
|
||||
example
|
||||
for example in examples
|
||||
if example["ideal_completions_data"] is not None
|
||||
and example["ideal_completions_data"]["ideal_completions_group"]
|
||||
== self.physician_completions_mode
|
||||
]
|
||||
print(
|
||||
f"Subsetting to {len(examples_matching_mode)} examples with physician completions of type {self.physician_completions_mode} ({PHYSICIAN_COMPLETION_MODES[self.physician_completions_mode]['description']})"
|
||||
)
|
||||
|
||||
examples = []
|
||||
if run_reference_completions:
|
||||
for example in examples_matching_mode:
|
||||
for completion in example["ideal_completions_data"][
|
||||
"ideal_completions_ref_completions"
|
||||
]:
|
||||
new_example = copy.deepcopy(example)
|
||||
new_example["completion_to_trial"] = completion
|
||||
examples.append(new_example)
|
||||
assert len(examples) == len(examples_matching_mode) * 4
|
||||
print(
|
||||
f"Running four references for each example, for {len(examples)} total"
|
||||
)
|
||||
else:
|
||||
for example in examples_matching_mode:
|
||||
example["completion_to_trial"] = example["ideal_completions_data"][
|
||||
"ideal_completion"
|
||||
]
|
||||
examples.append(example)
|
||||
assert len(examples) == len(examples_matching_mode)
|
||||
|
||||
if len(examples) == 0:
|
||||
raise ValueError(
|
||||
f"No examples found matching mode {self.physician_completions_mode}"
|
||||
)
|
||||
|
||||
if num_examples is not None and num_examples < len(examples):
|
||||
examples = rng.sample(
|
||||
examples,
|
||||
num_examples,
|
||||
)
|
||||
|
||||
self.examples = examples * n_repeats
|
||||
self.n_threads = n_threads
|
||||
self.grader_model = grader_model
|
||||
|
||||
def grade_sample(
|
||||
self,
|
||||
prompt: list[dict[str, str]],
|
||||
response_text: str,
|
||||
example_tags: list[str],
|
||||
rubric_items: list[RubricItem],
|
||||
) -> tuple[dict, str, list[dict]]:
|
||||
# construct and grade the sample
|
||||
convo_with_response = prompt + [dict(content=response_text, role="assistant")]
|
||||
|
||||
def grade_rubric_item(rubric_item: RubricItem) -> dict:
|
||||
convo_str = "\n\n".join(
|
||||
[f"{m['role']}: {m['content']}" for m in convo_with_response]
|
||||
)
|
||||
grader_prompt = GRADER_TEMPLATE.replace(
|
||||
"<<conversation>>", convo_str
|
||||
).replace("<<rubric_item>>", str(rubric_item))
|
||||
messages: MessageList = [dict(content=grader_prompt, role="user")]
|
||||
while True:
|
||||
sampler_response = self.grader_model(messages)
|
||||
grading_response = sampler_response.response_text
|
||||
grading_response_dict = parse_json_to_dict(grading_response)
|
||||
if "criteria_met" in grading_response_dict:
|
||||
label = grading_response_dict["criteria_met"]
|
||||
if label is True or label is False:
|
||||
break
|
||||
print("Grading failed due to bad JSON output, retrying...")
|
||||
return grading_response_dict
|
||||
|
||||
grading_response_list = report.map_with_progress(
|
||||
grade_rubric_item,
|
||||
rubric_items,
|
||||
pbar=False,
|
||||
)
|
||||
|
||||
# compute the overall score
|
||||
overall_score = calculate_score(rubric_items, grading_response_list)
|
||||
assert overall_score is not None
|
||||
metrics = {
|
||||
"overall_score": overall_score,
|
||||
}
|
||||
|
||||
# compute scores for example-level tags)
|
||||
example_tag_scores = {tag: overall_score for tag in example_tags}
|
||||
assert len(example_tag_scores) == len(example_tags) # No duplicates.
|
||||
metrics.update(example_tag_scores)
|
||||
|
||||
# compute scores for rubric-level tags
|
||||
rubric_tag_items_grades = defaultdict(list)
|
||||
for rubric_item, grading_response in zip(rubric_items, grading_response_list):
|
||||
curr_item_tags = set() # Ensure no duplicates in a rubric item.
|
||||
for tag in rubric_item.tags:
|
||||
rubric_tag_items_grades[tag].append((rubric_item, grading_response))
|
||||
assert tag not in curr_item_tags
|
||||
curr_item_tags.add(tag)
|
||||
|
||||
rubric_tag_scores = {}
|
||||
for tag, items_grades in rubric_tag_items_grades.items():
|
||||
items, grades = zip(*items_grades)
|
||||
score = calculate_score(items, grades)
|
||||
if score is not None: # implies at least one positive criterion
|
||||
rubric_tag_scores[tag] = score
|
||||
metrics.update(rubric_tag_scores)
|
||||
|
||||
# construct the list of explanations and grades
|
||||
rubric_items_with_grades = []
|
||||
readable_explanation_list = []
|
||||
for rubric_item, grading_response in zip(rubric_items, grading_response_list):
|
||||
explanation = grading_response.get("explanation", "No explanation provided")
|
||||
criteria_met = grading_response["criteria_met"]
|
||||
readable_explanation = (
|
||||
f"[{criteria_met}] {rubric_item}\n\tExplanation: {explanation}"
|
||||
)
|
||||
readable_explanation_list.append(readable_explanation)
|
||||
rubric_items_with_grades.append(
|
||||
{
|
||||
**rubric_item.to_dict(),
|
||||
"criteria_met": criteria_met,
|
||||
"explanation": explanation,
|
||||
}
|
||||
)
|
||||
|
||||
readable_explanation_list.sort(
|
||||
key=lambda x: x.startswith("[False]"), reverse=True
|
||||
)
|
||||
readable_explanation_str = "\n\n".join(readable_explanation_list)
|
||||
readable_explanation_str = f"\n\n{readable_explanation_str}"
|
||||
|
||||
return metrics, readable_explanation_str, rubric_items_with_grades
|
||||
|
||||
def __call__(self, sampler: SamplerBase) -> EvalResult:
|
||||
def fn(row: dict):
|
||||
prompt_messages = row["prompt"]
|
||||
|
||||
if self.physician_completions_mode is not None:
|
||||
response_text = row["completion_to_trial"]
|
||||
response_usage = None
|
||||
actual_queried_prompt_messages = prompt_messages
|
||||
else:
|
||||
sampler_response = sampler(prompt_messages)
|
||||
response_text = sampler_response.response_text
|
||||
response_dict = sampler_response.response_metadata
|
||||
actual_queried_prompt_messages = (
|
||||
sampler_response.actual_queried_message_list
|
||||
)
|
||||
response_usage = response_dict.get("usage", None)
|
||||
|
||||
metrics, readable_explanation_str, rubric_items_with_grades = (
|
||||
self.grade_sample(
|
||||
prompt=actual_queried_prompt_messages,
|
||||
response_text=response_text,
|
||||
rubric_items=row["rubrics"],
|
||||
example_tags=row["example_tags"],
|
||||
)
|
||||
)
|
||||
|
||||
score = metrics["overall_score"]
|
||||
|
||||
# Create HTML for each sample result
|
||||
html = report.jinja_env.from_string(
|
||||
HEALTHBENCH_HTML_JINJA.replace(
|
||||
"{{ rubric_grades }}",
|
||||
readable_explanation_str.replace("\n", "<br>"),
|
||||
)
|
||||
).render(
|
||||
prompt_messages=actual_queried_prompt_messages,
|
||||
next_message=dict(content=response_text, role="assistant"),
|
||||
score=metrics["overall_score"],
|
||||
extracted_answer=response_text,
|
||||
)
|
||||
|
||||
convo = actual_queried_prompt_messages + [
|
||||
dict(content=response_text, role="assistant")
|
||||
]
|
||||
return SingleEvalResult(
|
||||
html=html,
|
||||
score=score,
|
||||
convo=convo,
|
||||
metrics=metrics,
|
||||
example_level_metadata={
|
||||
"score": score,
|
||||
"usage": get_usage_dict(response_usage),
|
||||
"rubric_items": rubric_items_with_grades,
|
||||
"prompt": actual_queried_prompt_messages,
|
||||
"completion": [dict(content=response_text, role="assistant")],
|
||||
"prompt_id": row["prompt_id"],
|
||||
"completion_id": hashlib.sha256(
|
||||
(row["prompt_id"] + response_text).encode("utf-8")
|
||||
).hexdigest(),
|
||||
},
|
||||
)
|
||||
|
||||
results = report.map_with_progress(
|
||||
fn,
|
||||
self.examples,
|
||||
num_threads=self.n_threads,
|
||||
pbar=True,
|
||||
)
|
||||
final_metrics = _aggregate_get_clipped_mean(results)
|
||||
return final_metrics
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="HealthBenchEval specific run options, including e.g., running the eval on physician completions rows only."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run_mode",
|
||||
type=str,
|
||||
choices=["physician_completions", "physician_completion_references"],
|
||||
)
|
||||
parser.add_argument("--examples", type=int, help="Number of examples to run")
|
||||
parser.add_argument(
|
||||
"--n-threads",
|
||||
type=int,
|
||||
default=120,
|
||||
help="Number of threads to run",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.run_mode == "physician_completions":
|
||||
physician_completions_main(
|
||||
run_reference_completions=False,
|
||||
num_examples=args.examples,
|
||||
n_threads=args.n_threads or 1,
|
||||
)
|
||||
elif args.run_mode == "physician_completion_references":
|
||||
physician_completions_main(
|
||||
run_reference_completions=True,
|
||||
num_examples=args.examples,
|
||||
n_threads=args.n_threads or 1,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid run mode: {args.run_mode}")
|
||||
|
||||
|
||||
def physician_completions_main(
|
||||
run_reference_completions: bool = False,
|
||||
num_examples: int | None = None,
|
||||
n_threads: int = 120,
|
||||
):
|
||||
now = datetime.now()
|
||||
date_str = now.strftime("%Y%m%d_%H%M")
|
||||
|
||||
grading_sampler = ChatCompletionSampler(
|
||||
model="gpt-4.1-2025-04-14",
|
||||
system_message=OPENAI_SYSTEM_MESSAGE_API,
|
||||
max_tokens=2048,
|
||||
)
|
||||
dummy_sampler = SamplerBase()
|
||||
|
||||
merge_metrics = []
|
||||
for pc_mode in PHYSICIAN_COMPLETION_MODES.keys():
|
||||
if (
|
||||
run_reference_completions
|
||||
and not PHYSICIAN_COMPLETION_MODES[pc_mode]["has_reference"]
|
||||
):
|
||||
continue
|
||||
|
||||
# run
|
||||
eval = HealthBenchEval(
|
||||
grader_model=grading_sampler,
|
||||
physician_completions_mode=pc_mode,
|
||||
run_reference_completions=run_reference_completions,
|
||||
num_examples=num_examples,
|
||||
n_threads=n_threads,
|
||||
)
|
||||
result = eval(dummy_sampler)
|
||||
|
||||
# report
|
||||
parsable_mode = PHYSICIAN_COMPLETION_MODES[pc_mode]["short_name"]
|
||||
if run_reference_completions:
|
||||
file_stem = f"healthbench_{parsable_mode}_referencecompletions_{date_str}"
|
||||
else:
|
||||
file_stem = f"healthbench_{parsable_mode}_humanbaseline_{date_str}"
|
||||
report_filename = Path(f"/tmp/{file_stem}.html")
|
||||
report_filename.write_text(report.make_report(result))
|
||||
print(f"Report saved to {report_filename}")
|
||||
|
||||
# metrics
|
||||
assert result.metrics is not None
|
||||
metrics = result.metrics
|
||||
result_filename = Path(f"/tmp/{file_stem}.json")
|
||||
result_filename.write_text(json.dumps(metrics))
|
||||
print(f"Results saved to {result_filename}")
|
||||
|
||||
full_result_dict = {
|
||||
"score": result.score,
|
||||
"metrics": result.metrics,
|
||||
"htmls": result.htmls,
|
||||
"convos": result.convos,
|
||||
"metadata": result.metadata,
|
||||
}
|
||||
full_result_filename = Path(f"/tmp/{file_stem}_allresults.json")
|
||||
full_result_filename.write_text(json.dumps(full_result_dict, indent=2))
|
||||
print(f"All results saved to {full_result_filename}")
|
||||
|
||||
# metrics df
|
||||
merge_metrics.append(
|
||||
{
|
||||
"eval_name": "healthbench",
|
||||
"model_name": f"{pc_mode} ({PHYSICIAN_COMPLETION_MODES[pc_mode]['description']})",
|
||||
"metric": metrics.get("overall_score", None),
|
||||
}
|
||||
)
|
||||
|
||||
print("\nAll results: ")
|
||||
print(merge_metrics)
|
||||
return merge_metrics
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
207
gpt_oss/evals/report.py
Normal file
207
gpt_oss/evals/report.py
Normal file
@@ -0,0 +1,207 @@
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from typing import Any, Callable
|
||||
|
||||
import jinja2
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from .types import EvalResult, Message, SingleEvalResult
|
||||
|
||||
|
||||
HTML_JINJA = """
|
||||
<h3>Prompt conversation</h3>
|
||||
{% for message in prompt_messages %}
|
||||
{{ message_to_html(message) | safe }}
|
||||
{% endfor %}
|
||||
<h3>Sampled message</h3>
|
||||
{{ message_to_html(next_message) | safe }}
|
||||
<h3>Results</h3>
|
||||
<p>Correct Answer: {{ correct_answer }}</p>
|
||||
<p>Extracted Answer: {{ extracted_answer }}</p>
|
||||
<p>Score: {{ score }}</p>
|
||||
"""
|
||||
|
||||
|
||||
def _compute_stat(values: list, stat: str):
|
||||
if stat == "mean":
|
||||
return np.mean(values)
|
||||
elif stat == "std":
|
||||
return np.std(values)
|
||||
elif stat == "min":
|
||||
return np.min(values)
|
||||
elif stat == "max":
|
||||
return np.max(values)
|
||||
elif stat == "n_samples":
|
||||
return len(values)
|
||||
elif stat == "bootstrap_std":
|
||||
return np.std(
|
||||
[np.mean(np.random.choice(values, len(values))) for _ in range(1000)]
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown {stat =}")
|
||||
|
||||
|
||||
def aggregate_results(
|
||||
single_eval_results: list[SingleEvalResult],
|
||||
default_stats: tuple[str, ...] = ("mean", "std"),
|
||||
name2stats: dict[str, tuple[str]] | None = None,
|
||||
) -> EvalResult:
|
||||
"""
|
||||
Aggregate results from multiple evaluations into a single EvalResult.
|
||||
"""
|
||||
name2stats = name2stats or {}
|
||||
name2values = defaultdict(list)
|
||||
htmls = []
|
||||
convos = []
|
||||
metadata = []
|
||||
for single_eval_result in single_eval_results:
|
||||
for name, value in single_eval_result.metrics.items():
|
||||
name2values[name].append(value)
|
||||
if single_eval_result.score is not None:
|
||||
name2values["score"].append(single_eval_result.score)
|
||||
htmls.append(single_eval_result.html)
|
||||
convos.append(single_eval_result.convo)
|
||||
metadata.append(single_eval_result.example_level_metadata)
|
||||
final_metrics = {}
|
||||
for name, values in name2values.items():
|
||||
stats = name2stats.get(name, default_stats)
|
||||
for stat in stats:
|
||||
key = name if stat == "mean" else f"{name}:{stat}"
|
||||
final_metrics[key] = _compute_stat(values, stat)
|
||||
return EvalResult(
|
||||
score=final_metrics.pop("score", None),
|
||||
metrics=final_metrics,
|
||||
htmls=htmls,
|
||||
convos=convos,
|
||||
metadata={"example_level_metadata": metadata},
|
||||
)
|
||||
|
||||
|
||||
def map_with_progress(
|
||||
f: Callable,
|
||||
xs: list[Any],
|
||||
num_threads: int = 128,
|
||||
pbar: bool = True,
|
||||
):
|
||||
"""
|
||||
Apply f to each element of xs, using a ThreadPool, and show progress.
|
||||
"""
|
||||
pbar_fn = tqdm if pbar else lambda x, *args, **kwargs: x
|
||||
|
||||
if os.getenv("debug"):
|
||||
return list(map(f, pbar_fn(xs, total=len(xs))))
|
||||
else:
|
||||
with ThreadPool(min(num_threads, len(xs))) as pool:
|
||||
return list(pbar_fn(pool.imap_unordered(f, xs), total=len(xs)))
|
||||
|
||||
|
||||
jinja_env = jinja2.Environment(
|
||||
loader=jinja2.BaseLoader(),
|
||||
undefined=jinja2.StrictUndefined,
|
||||
autoescape=jinja2.select_autoescape(["html", "xml"]),
|
||||
)
|
||||
_message_template = """
|
||||
<div class="message {{ role }}">
|
||||
<div class="role">
|
||||
{{ role }}
|
||||
{% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
|
||||
</div>
|
||||
<div class="content">
|
||||
<pre>{{ content }}</pre>
|
||||
</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
|
||||
def message_to_html(message: Message) -> str:
|
||||
"""
|
||||
Generate HTML snippet (inside a <div>) for a message.
|
||||
"""
|
||||
return jinja_env.from_string(_message_template).render(
|
||||
role=message["role"],
|
||||
content=message["content"],
|
||||
variant=message.get("variant", None),
|
||||
)
|
||||
|
||||
|
||||
jinja_env.globals["message_to_html"] = message_to_html
|
||||
|
||||
|
||||
_report_template = """<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<style>
|
||||
.message {
|
||||
padding: 8px 16px;
|
||||
margin-bottom: 8px;
|
||||
border-radius: 4px;
|
||||
}
|
||||
.message.user {
|
||||
background-color: #B2DFDB;
|
||||
color: #00695C;
|
||||
}
|
||||
.message.assistant {
|
||||
background-color: #B39DDB;
|
||||
color: #4527A0;
|
||||
}
|
||||
.message.system {
|
||||
background-color: #EEEEEE;
|
||||
color: #212121;
|
||||
}
|
||||
.role {
|
||||
font-weight: bold;
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
.variant {
|
||||
color: #795548;
|
||||
}
|
||||
table, th, td {
|
||||
border: 1px solid black;
|
||||
}
|
||||
pre {
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
{% if metrics %}
|
||||
<h1>Metrics</h1>
|
||||
<table>
|
||||
<tr>
|
||||
<th>Metric</th>
|
||||
<th>Value</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><b>Score</b></td>
|
||||
<td>{{ score | float | round(3) }}</td>
|
||||
</tr>
|
||||
{% for name, value in metrics.items() %}
|
||||
<tr>
|
||||
<td>{{ name }}</td>
|
||||
<td>{{ value }}</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</table>
|
||||
{% endif %}
|
||||
<h1>Examples</h1>
|
||||
{% for html in htmls %}
|
||||
{{ html | safe }}
|
||||
<hr>
|
||||
{% endfor %}
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
def make_report(eval_result: EvalResult) -> str:
|
||||
"""
|
||||
Create a standalone HTML report from an EvalResult.
|
||||
"""
|
||||
return jinja_env.from_string(_report_template).render(
|
||||
score=eval_result.score,
|
||||
metrics=eval_result.metrics,
|
||||
htmls=eval_result.htmls,
|
||||
)
|
||||
93
gpt_oss/evals/responses_sampler.py
Normal file
93
gpt_oss/evals/responses_sampler.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
from openai import OpenAI
|
||||
|
||||
from .types import MessageList, SamplerBase, SamplerResponse
|
||||
|
||||
|
||||
class ResponsesSampler(SamplerBase):
|
||||
"""
|
||||
Sample from OpenAI's responses API
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
developer_message: str | None = None,
|
||||
temperature: float = 1.0,
|
||||
max_tokens: int = 1024,
|
||||
reasoning_model: bool = False,
|
||||
reasoning_effort: str | None = None,
|
||||
base_url: str = "http://localhost:8000/v1",
|
||||
):
|
||||
self.api_key_name = "OPENAI_API_KEY"
|
||||
self.client = OpenAI(base_url=base_url, timeout=24*60*60)
|
||||
self.model = model
|
||||
self.developer_message = developer_message
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.image_format = "url"
|
||||
self.reasoning_model = reasoning_model
|
||||
self.reasoning_effort = reasoning_effort
|
||||
|
||||
def _pack_message(self, role: str, content: Any) -> dict[str, Any]:
|
||||
return {"role": role, "content": content}
|
||||
|
||||
def __call__(self, message_list: MessageList) -> SamplerResponse:
|
||||
if self.developer_message:
|
||||
message_list = [
|
||||
self._pack_message("developer", self.developer_message)
|
||||
] + message_list
|
||||
trial = 0
|
||||
while True:
|
||||
try:
|
||||
if self.reasoning_model:
|
||||
reasoning = (
|
||||
{"effort": self.reasoning_effort}
|
||||
if self.reasoning_effort
|
||||
else None
|
||||
)
|
||||
response = self.client.responses.create(
|
||||
model=self.model,
|
||||
input=message_list,
|
||||
reasoning=reasoning,
|
||||
)
|
||||
else:
|
||||
response = self.client.responses.create(
|
||||
model=self.model,
|
||||
input=message_list,
|
||||
temperature=self.temperature,
|
||||
max_output_tokens=self.max_tokens,
|
||||
)
|
||||
|
||||
for output in response.output:
|
||||
if hasattr(output, "text"):
|
||||
message_list.append(self._pack_message(getattr(output, "role", "assistant"), output.text))
|
||||
elif hasattr(output, "content"):
|
||||
for c in output.content:
|
||||
# c.text handled below
|
||||
pass
|
||||
|
||||
return SamplerResponse(
|
||||
response_text=response.output_text,
|
||||
response_metadata={"usage": response.usage},
|
||||
actual_queried_message_list=message_list,
|
||||
)
|
||||
except openai.BadRequestError as e:
|
||||
print("Bad Request Error", e)
|
||||
return SamplerResponse(
|
||||
response_text="",
|
||||
response_metadata={"usage": None},
|
||||
actual_queried_message_list=message_list,
|
||||
)
|
||||
except Exception as e:
|
||||
exception_backoff = 2**trial # expontial back off
|
||||
print(
|
||||
f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec",
|
||||
e,
|
||||
)
|
||||
time.sleep(exception_backoff)
|
||||
trial += 1
|
||||
# unknown error shall throw exception
|
||||
66
gpt_oss/evals/types.py
Normal file
66
gpt_oss/evals/types.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, overload
|
||||
|
||||
Message = dict[str, Any] # keys role, content
|
||||
MessageList = list[Message]
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplerResponse:
|
||||
"""
|
||||
Response from a sampler.
|
||||
"""
|
||||
response_text: str
|
||||
actual_queried_message_list: MessageList
|
||||
response_metadata: dict[str, Any]
|
||||
|
||||
class SamplerBase:
|
||||
"""
|
||||
Base class for defining a sampling model, which can be evaluated,
|
||||
or used as part of the grading process.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
message_list: MessageList,
|
||||
) -> SamplerResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalResult:
|
||||
"""
|
||||
Result of running an evaluation (usually consisting of many samples)
|
||||
"""
|
||||
|
||||
score: float | None # top-line metric
|
||||
metrics: dict[str, float] | None # other metrics
|
||||
htmls: list[str] # strings of valid HTML
|
||||
convos: list[MessageList] # sampled conversations
|
||||
metadata: dict[str, Any] | None # Extra data such as rubric scores or sollen
|
||||
|
||||
|
||||
@dataclass
|
||||
class SingleEvalResult:
|
||||
"""
|
||||
Result of evaluating a single sample
|
||||
"""
|
||||
|
||||
score: float | None
|
||||
metrics: dict[str, float] = field(default_factory=dict)
|
||||
html: str | None = None
|
||||
convo: MessageList | None = None # sampled conversation
|
||||
example_level_metadata: dict[str, Any] | None = (
|
||||
None # Extra data such as rubric scores or sollen
|
||||
)
|
||||
|
||||
|
||||
class Eval:
|
||||
"""
|
||||
Base class for defining an evaluation.
|
||||
"""
|
||||
|
||||
def __call__(self, sampler: SamplerBase) -> EvalResult:
|
||||
raise NotImplementedError
|
||||
|
||||
82
gpt_oss/generate.py
Normal file
82
gpt_oss/generate.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# Model parallel inference
|
||||
# Note: This script is for demonstration purposes only. It is not designed for production use.
|
||||
# See gpt_oss.chat for a more complete example with the Harmony parser.
|
||||
# torchrun --nproc-per-node=4 -m gpt_oss.generate -p "why did the chicken cross the road?" model/
|
||||
|
||||
import argparse
|
||||
|
||||
from gpt_oss.tokenizer import get_tokenizer
|
||||
|
||||
|
||||
def main(args):
|
||||
match args.backend:
|
||||
case "torch":
|
||||
from gpt_oss.torch.utils import init_distributed
|
||||
from gpt_oss.torch.model import TokenGenerator as TorchGenerator
|
||||
device = init_distributed()
|
||||
generator = TorchGenerator(args.checkpoint, device=device)
|
||||
case "triton":
|
||||
from gpt_oss.torch.utils import init_distributed
|
||||
from gpt_oss.triton.model import TokenGenerator as TritonGenerator
|
||||
device = init_distributed()
|
||||
generator = TritonGenerator(args.checkpoint, context=4096, device=device)
|
||||
case "vllm":
|
||||
from gpt_oss.vllm.token_generator import TokenGenerator as VLLMGenerator
|
||||
generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=2)
|
||||
case _:
|
||||
raise ValueError(f"Invalid backend: {args.backend}")
|
||||
|
||||
tokenizer = get_tokenizer()
|
||||
tokens = tokenizer.encode(args.prompt)
|
||||
for token, logprob in generator.generate(tokens, stop_tokens=[tokenizer.eot_token], temperature=args.temperature, max_tokens=args.limit, return_logprobs=True):
|
||||
tokens.append(token)
|
||||
decoded_token = tokenizer.decode([token])
|
||||
print(
|
||||
f"Generated token: {repr(decoded_token)}, logprob: {logprob}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Text generation example")
|
||||
parser.add_argument(
|
||||
"checkpoint",
|
||||
metavar="FILE",
|
||||
type=str,
|
||||
help="Path to the SafeTensors checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--prompt",
|
||||
metavar="PROMPT",
|
||||
type=str,
|
||||
default="How are you?",
|
||||
help="LLM prompt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--temperature",
|
||||
metavar="TEMP",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Sampling temperature",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--limit",
|
||||
metavar="LIMIT",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Limit on the number of tokens (0 to disable)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--backend",
|
||||
metavar="BACKEND",
|
||||
type=str,
|
||||
default="torch",
|
||||
choices=["triton", "torch", "vllm"],
|
||||
help="Inference backend",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
177
gpt_oss/metal/CMakeLists.txt
Normal file
177
gpt_oss/metal/CMakeLists.txt
Normal file
@@ -0,0 +1,177 @@
|
||||
cmake_minimum_required(VERSION 3.24)
|
||||
project(GPTOSS
|
||||
VERSION 1.0
|
||||
DESCRIPTION "Local GPT-OSS inference"
|
||||
LANGUAGES C CXX OBJC)
|
||||
|
||||
set(CMAKE_C_STANDARD 11)
|
||||
set(CMAKE_CXX_STANDARD 20)
|
||||
set(CMAKE_OBJC_STANDARD 11)
|
||||
set(CMAKE_OBJC_STANDARD_REQUIRED ON)
|
||||
|
||||
find_library(FOUNDATION_FRAMEWORK Foundation REQUIRED)
|
||||
find_library(METAL_FRAMEWORK Metal REQUIRED)
|
||||
find_library(IOKIT_FRAMEWORK IOKit REQUIRED)
|
||||
|
||||
set(METAL_SOURCES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/source/accumulate.metal
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/source/convert.metal
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/source/embeddings.metal
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/source/matmul.metal
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/source/moematmul.metal
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/source/random.metal
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/source/rmsnorm.metal
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/source/rope.metal
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/source/sample.metal
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/source/sdpa.metal
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/source/topk.metal
|
||||
)
|
||||
set(METAL_LIB default.metallib)
|
||||
|
||||
include_directories(BEFORE include source/include)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_CURRENT_BINARY_DIR}/source/"
|
||||
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/accumulate.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/accumulate.air"
|
||||
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/convert.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/convert.air"
|
||||
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/embeddings.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/embeddings.air"
|
||||
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/matmul.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/matmul.air"
|
||||
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/moematmul.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/moematmul.air"
|
||||
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/random.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/random.air"
|
||||
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/rmsnorm.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/rmsnorm.air"
|
||||
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/rope.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/rope.air"
|
||||
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/sample.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/sample.air"
|
||||
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/sdpa.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/sdpa.air"
|
||||
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/topk.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/topk.air"
|
||||
COMMAND xcrun -sdk macosx metallib "${CMAKE_CURRENT_BINARY_DIR}/source/accumulate.air" "${CMAKE_CURRENT_BINARY_DIR}/source/convert.air" "${CMAKE_CURRENT_BINARY_DIR}/source/embeddings.air" "${CMAKE_CURRENT_BINARY_DIR}/source/matmul.air" "${CMAKE_CURRENT_BINARY_DIR}/source/moematmul.air" "${CMAKE_CURRENT_BINARY_DIR}/source/random.air" "${CMAKE_CURRENT_BINARY_DIR}/source/rmsnorm.air" "${CMAKE_CURRENT_BINARY_DIR}/source/rope.air" "${CMAKE_CURRENT_BINARY_DIR}/source/sample.air" "${CMAKE_CURRENT_BINARY_DIR}/source/sdpa.air" "${CMAKE_CURRENT_BINARY_DIR}/source/topk.air" -o "${METAL_LIB}"
|
||||
DEPENDS ${METAL_SOURCES}
|
||||
COMMENT "Compiling Metal compute library"
|
||||
)
|
||||
|
||||
add_custom_target(build_metallib ALL
|
||||
DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB})
|
||||
|
||||
add_library(log OBJECT source/log.c)
|
||||
|
||||
add_library(metal-kernels STATIC source/metal.m source/metal-kernels.c)
|
||||
target_link_libraries(metal-kernels PRIVATE log)
|
||||
|
||||
add_dependencies(metal-kernels build_metallib)
|
||||
add_custom_command(TARGET metal-kernels POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
|
||||
$<TARGET_FILE_DIR:metal-kernels>)
|
||||
|
||||
target_link_libraries(metal-kernels PRIVATE ${FOUNDATION_FRAMEWORK} ${METAL_FRAMEWORK} ${IOKIT_FRAMEWORK})
|
||||
|
||||
add_library(gptoss STATIC source/model.c source/tokenizer.c source/context.c)
|
||||
target_link_libraries(gptoss PRIVATE log metal-kernels)
|
||||
|
||||
add_executable(generate source/generate.c)
|
||||
target_link_libraries(generate gptoss)
|
||||
|
||||
# --- [ Tests
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(
|
||||
googletest
|
||||
URL https://github.com/google/googletest/archive/refs/tags/v1.17.0.zip
|
||||
DOWNLOAD_EXTRACT_TIMESTAMP OFF
|
||||
)
|
||||
# For Windows: Prevent overriding the parent project's compiler/linker settings
|
||||
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
|
||||
set(INSTALL_GTEST OFF CACHE BOOL "" FORCE)
|
||||
FetchContent_MakeAvailable(googletest)
|
||||
|
||||
enable_testing()
|
||||
|
||||
add_executable(u32-random-test test/u32-random.cc)
|
||||
target_link_libraries(u32-random-test PRIVATE GTest::gtest_main metal-kernels)
|
||||
target_include_directories(u32-random-test PRIVATE source/include)
|
||||
add_test(NAME u32-random-test COMMAND u32-random-test)
|
||||
|
||||
add_executable(f32-random-test test/f32-random.cc)
|
||||
target_link_libraries(f32-random-test PRIVATE GTest::gtest_main metal-kernels)
|
||||
target_include_directories(f32-random-test PRIVATE source/include)
|
||||
add_test(NAME f32-random-test COMMAND f32-random-test)
|
||||
|
||||
add_executable(mf4-f32-convert-test test/mf4-f32-convert.cc)
|
||||
target_link_libraries(mf4-f32-convert-test PRIVATE GTest::gtest_main metal-kernels)
|
||||
target_include_directories(mf4-f32-convert-test PRIVATE source/include)
|
||||
add_test(NAME mf4-f32-convert-test COMMAND mf4-f32-convert-test)
|
||||
|
||||
add_executable(bf16-f32-embeddings-test test/bf16-f32-embeddings.cc)
|
||||
target_link_libraries(bf16-f32-embeddings-test PRIVATE GTest::gtest_main metal-kernels)
|
||||
target_include_directories(bf16-f32-embeddings-test PRIVATE source/include)
|
||||
add_test(NAME bf16-f32-embeddings-test COMMAND bf16-f32-embeddings-test)
|
||||
|
||||
add_executable(f32-bf16w-rmsnorm-test test/f32-bf16w-rmsnorm.cc)
|
||||
target_link_libraries(f32-bf16w-rmsnorm-test PRIVATE GTest::gtest_main metal-kernels)
|
||||
target_include_directories(f32-bf16w-rmsnorm-test PRIVATE source/include)
|
||||
add_test(NAME f32-bf16w-rmsnorm-test COMMAND f32-bf16w-rmsnorm-test)
|
||||
|
||||
add_executable(f32-bf16w-matmul-test test/f32-bf16w-matmul.cc)
|
||||
target_link_libraries(f32-bf16w-matmul-test PRIVATE GTest::gtest_main metal-kernels)
|
||||
target_include_directories(f32-bf16w-matmul-test PRIVATE source/include)
|
||||
add_test(NAME f32-bf16w-matmul-test COMMAND f32-bf16w-matmul-test)
|
||||
|
||||
add_executable(f32-rope-test test/f32-rope.cc)
|
||||
target_link_libraries(f32-rope-test PRIVATE GTest::gtest_main metal-kernels)
|
||||
target_include_directories(f32-rope-test PRIVATE source/include)
|
||||
add_test(NAME f32-rope-test COMMAND f32-rope-test)
|
||||
|
||||
# --- [ Benchmarks
|
||||
include(FetchContent)
|
||||
set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable self-tests in Google Benchmark" FORCE)
|
||||
set(BENCHMARK_ENABLE_INSTALL OFF CACHE BOOL "Disable installation of Google Benchmark" FORCE)
|
||||
FetchContent_Declare(
|
||||
benchmark
|
||||
URL https://github.com/google/benchmark/archive/refs/tags/v1.9.4.zip
|
||||
DOWNLOAD_EXTRACT_TIMESTAMP OFF
|
||||
)
|
||||
FetchContent_MakeAvailable(benchmark)
|
||||
|
||||
add_executable(f32-random-bench benchmark/f32-random.cc)
|
||||
target_link_libraries(f32-random-bench PRIVATE benchmark::benchmark metal-kernels)
|
||||
target_include_directories(f32-random-bench PRIVATE source/include)
|
||||
|
||||
add_executable(u32-random-bench benchmark/u32-random.cc)
|
||||
target_link_libraries(u32-random-bench PRIVATE benchmark::benchmark metal-kernels)
|
||||
target_include_directories(u32-random-bench PRIVATE source/include)
|
||||
|
||||
add_executable(mf4-f32-convert-bench benchmark/mf4-f32-convert.cc)
|
||||
target_link_libraries(mf4-f32-convert-bench PRIVATE benchmark::benchmark metal-kernels)
|
||||
target_include_directories(mf4-f32-convert-bench PRIVATE source/include)
|
||||
|
||||
add_executable(f32-bf16w-rmsnorm-bench benchmark/f32-bf16w-rmsnorm.cc)
|
||||
target_link_libraries(f32-bf16w-rmsnorm-bench PRIVATE benchmark::benchmark metal-kernels)
|
||||
target_include_directories(f32-bf16w-rmsnorm-bench PRIVATE source/include)
|
||||
|
||||
# --- [ Python extension ] -----------------------------------------------
|
||||
find_package(pybind11 CONFIG REQUIRED) # provides pybind11_add_module
|
||||
|
||||
pybind11_add_module(_metal
|
||||
python/module.c
|
||||
python/context.c
|
||||
python/model.c
|
||||
python/tokenizer.c
|
||||
)
|
||||
set_target_properties(_metal PROPERTIES PREFIX "")
|
||||
|
||||
target_link_libraries(_metal PRIVATE gptoss)
|
||||
add_dependencies(_metal build_metallib)
|
||||
target_link_options(_metal PRIVATE
|
||||
LINKER:-sectcreate,__METAL,__shaders,${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
|
||||
)
|
||||
add_custom_command(TARGET _metal POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
|
||||
$<TARGET_FILE_DIR:_metal>)
|
||||
|
||||
# 1️⃣ install the extension module into the Python package
|
||||
install(TARGETS _metal LIBRARY DESTINATION gpt_oss/metal)
|
||||
|
||||
# 2️⃣ make sure the Metal shader archive travels with it
|
||||
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
|
||||
DESTINATION gpt_oss/metal)
|
||||
# ------------------------------------------------------------------------
|
||||
6
gpt_oss/metal/__init__.py
Normal file
6
gpt_oss/metal/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from importlib import import_module as _im
|
||||
|
||||
# Load the compiled extension (gpt_oss.metal._metal)
|
||||
_ext = _im(f"{__name__}._metal")
|
||||
globals().update({k: v for k, v in _ext.__dict__.items() if not k.startswith("_")})
|
||||
del _im, _ext
|
||||
95
gpt_oss/metal/benchmark/f32-bf16w-rmsnorm.cc
Normal file
95
gpt_oss/metal/benchmark/f32-bf16w-rmsnorm.cc
Normal file
@@ -0,0 +1,95 @@
|
||||
#include <gpt-oss.h>
|
||||
#include <internal/datatype.h>
|
||||
#include <internal/metal.hpp>
|
||||
#include <internal/metal-kernels.h>
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include <benchmark/benchmark.h>
|
||||
|
||||
using gptoss::Check;
|
||||
using namespace gptoss::metal;
|
||||
|
||||
constexpr float kEpsilon = 1.0e-5f;
|
||||
constexpr uint64_t kSeed = UINT64_C(1019827666124465388);
|
||||
|
||||
static void f32_bf16w_rnsnorm(benchmark::State& state) {
|
||||
const size_t num_tokens = 1;
|
||||
const size_t num_channels = state.range(0);
|
||||
|
||||
Device device;
|
||||
CommandQueue command_queue{device};
|
||||
Library library{device};
|
||||
Function f32_fill_random_fn{library, "gptoss_f32_fill_random"};
|
||||
Function bf16_fill_random_fn{library, "gptoss_bf16_fill_random"};
|
||||
Function f32_bf16w_rmsnorm_fn{library, "gptoss_f32_bf16w_rmsnorm"};
|
||||
Buffer input_buffer{device, num_tokens * num_channels * sizeof(float)};
|
||||
Buffer weight_buffer{device, num_channels * sizeof(gptoss_bfloat16)};
|
||||
Buffer output_buffer{device, num_tokens * num_channels * sizeof(float)};
|
||||
|
||||
{
|
||||
CommandBuffer command_buffer{command_queue};
|
||||
|
||||
size_t offset = 0;
|
||||
Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(
|
||||
command_buffer.handle(),
|
||||
f32_fill_random_fn.handle(),
|
||||
/*threadgroup_size=*/0,
|
||||
/*max_threadgroups=*/10,
|
||||
/*output_buffer=*/input_buffer.handle(),
|
||||
/*output_offset=*/0,
|
||||
num_channels, kSeed, offset, /*min=*/-1.0f, /*max=*/1.0),
|
||||
"gptoss_metal_command_buffer_encode_launch_f32_fill_random");
|
||||
offset += num_channels;
|
||||
|
||||
Check(gptoss_metal_command_buffer_encode_launch_bf16_fill_random(
|
||||
command_buffer.handle(),
|
||||
bf16_fill_random_fn.handle(),
|
||||
/*threadgroup_size=*/0,
|
||||
/*max_threadgroups=*/10,
|
||||
/*output_buffer=*/weight_buffer.handle(),
|
||||
/*output_offset=*/0,
|
||||
num_channels, kSeed, offset, /*min=*/-1.0f, /*max=*/1.0),
|
||||
"gptoss_metal_command_buffer_encode_launch_bf16_fill_random");
|
||||
offset += num_channels;
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_completion();
|
||||
}
|
||||
|
||||
for (auto _ : state) {
|
||||
CommandBuffer command_buffer{command_queue};
|
||||
|
||||
Check(gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
|
||||
command_buffer.handle(),
|
||||
f32_bf16w_rmsnorm_fn.handle(),
|
||||
input_buffer.handle(),
|
||||
/*input_offset=*/0,
|
||||
weight_buffer.handle(),
|
||||
/*weight_offset=*/0,
|
||||
output_buffer.handle(),
|
||||
/*output_offset=*/0,
|
||||
num_tokens,
|
||||
num_channels,
|
||||
kEpsilon),
|
||||
"gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm");
|
||||
|
||||
command_buffer.commit();
|
||||
const double elapsed_seconds = command_buffer.wait_completion();
|
||||
state.SetIterationTime(elapsed_seconds);
|
||||
}
|
||||
|
||||
const size_t num_elements = num_tokens * num_channels;
|
||||
state.counters["elements"] =
|
||||
benchmark::Counter(state.iterations() * num_elements,
|
||||
benchmark::Counter::kIsRate);
|
||||
|
||||
const int64_t bytes_per_iteration = input_buffer.size() + weight_buffer.size() + output_buffer.size();
|
||||
state.counters["bytes"] =
|
||||
benchmark::Counter(state.iterations() * bytes_per_iteration,
|
||||
benchmark::Counter::kIsRate);
|
||||
}
|
||||
|
||||
BENCHMARK(f32_bf16w_rnsnorm)->Arg(2880)->UseManualTime()->Unit(benchmark::kMicrosecond);
|
||||
|
||||
BENCHMARK_MAIN();
|
||||
55
gpt_oss/metal/benchmark/f32-random.cc
Normal file
55
gpt_oss/metal/benchmark/f32-random.cc
Normal file
@@ -0,0 +1,55 @@
|
||||
#include <gpt-oss.h>
|
||||
#include <internal/metal.hpp>
|
||||
#include <internal/metal-kernels.h>
|
||||
|
||||
#include <benchmark/benchmark.h>
|
||||
|
||||
using gptoss::Check;
|
||||
using namespace gptoss::metal;
|
||||
|
||||
static void f32_fill_random(benchmark::State& state) {
|
||||
const size_t numel = state.range(0);
|
||||
|
||||
Device device;
|
||||
CommandQueue command_queue{device};
|
||||
Library library{device};
|
||||
Function f32_fill_random_fn{library, "gptoss_f32_fill_random"};
|
||||
Buffer buffer{device, numel * sizeof(float)};
|
||||
|
||||
constexpr uint64_t seed = UINT64_C(1019827666124465388);
|
||||
constexpr uint64_t offset = UINT64_C(12345678901234567890);
|
||||
const float min = -1.0f;
|
||||
const float max = 7.0f;
|
||||
for (auto _ : state) {
|
||||
CommandBuffer command_buffer{command_queue};
|
||||
|
||||
Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(
|
||||
command_buffer.handle(),
|
||||
f32_fill_random_fn.handle(),
|
||||
/*threadgroup_size=*/0,
|
||||
/*max_threadgroups=*/120,
|
||||
/*output_buffer=*/buffer.handle(),
|
||||
/*output_offset=*/0,
|
||||
numel, seed, offset, min, max),
|
||||
"gptoss_metal_command_buffer_encode_launch_f32_fill_random");
|
||||
|
||||
command_buffer.commit();
|
||||
const double elapsed_seconds = command_buffer.wait_completion();
|
||||
state.SetIterationTime(elapsed_seconds);
|
||||
}
|
||||
|
||||
const int64_t elements_per_iteration = numel;
|
||||
state.counters["elements"] =
|
||||
benchmark::Counter(state.iterations() * elements_per_iteration,
|
||||
benchmark::Counter::kIsRate);
|
||||
|
||||
const int64_t bytes_per_iteration = numel * sizeof(float);
|
||||
state.counters["bytes"] =
|
||||
benchmark::Counter(state.iterations() * bytes_per_iteration,
|
||||
benchmark::Counter::kIsRate);
|
||||
}
|
||||
|
||||
constexpr int64_t giga = INT64_C(1073741824);
|
||||
BENCHMARK(f32_fill_random)->Arg(2 * giga)->UseManualTime()->Unit(benchmark::kMicrosecond);
|
||||
|
||||
BENCHMARK_MAIN();
|
||||
65
gpt_oss/metal/benchmark/mf4-f32-convert.cc
Normal file
65
gpt_oss/metal/benchmark/mf4-f32-convert.cc
Normal file
@@ -0,0 +1,65 @@
|
||||
#include <gpt-oss.h>
|
||||
#include <internal/datatype.h>
|
||||
#include <internal/metal.hpp>
|
||||
#include <internal/metal-kernels.h>
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include <benchmark/benchmark.h>
|
||||
|
||||
using gptoss::Check;
|
||||
using namespace gptoss::metal;
|
||||
|
||||
static void mf4_f32_convert(benchmark::State& state) {
|
||||
const size_t num_blocks = state.range(0);
|
||||
const size_t num_elements = num_blocks * 32;
|
||||
const size_t num_bytes = num_elements / 2;
|
||||
|
||||
Device device;
|
||||
CommandQueue command_queue{device};
|
||||
Library library{device};
|
||||
Function mf4_f32_convert_fn{library, "gptoss_mf4_f32_convert"};
|
||||
Buffer block_buffer{device, num_bytes};
|
||||
Buffer scale_buffer{device, num_blocks * sizeof(gptoss_float8ue8m0)};
|
||||
Buffer output_buffer{device, num_elements * sizeof(float)};
|
||||
|
||||
std::memset(block_buffer.ptr(), 0x91, num_bytes); // force subnormals
|
||||
std::memset(scale_buffer.ptr(), 128, num_blocks * sizeof(uint8_t)); // scale = 2.0
|
||||
|
||||
for (auto _ : state) {
|
||||
CommandBuffer command_buffer{command_queue};
|
||||
|
||||
Check(gptoss_metal_command_buffer_encode_launch_mf4_f32_convert(
|
||||
command_buffer.handle(),
|
||||
mf4_f32_convert_fn.handle(),
|
||||
/*threadgroup_size=*/0,
|
||||
/*max_threadgroups=*/120,
|
||||
block_buffer.handle(),
|
||||
scale_buffer.handle(),
|
||||
output_buffer.handle(),
|
||||
num_elements),
|
||||
"gptoss_metal_command_buffer_encode_launch_mf4_f32_convert");
|
||||
|
||||
command_buffer.commit();
|
||||
const double elapsed_seconds = command_buffer.wait_completion();
|
||||
state.SetIterationTime(elapsed_seconds);
|
||||
}
|
||||
|
||||
state.counters["blocks"] =
|
||||
benchmark::Counter(state.iterations() * num_blocks,
|
||||
benchmark::Counter::kIsRate);
|
||||
|
||||
state.counters["elements"] =
|
||||
benchmark::Counter(state.iterations() * num_elements,
|
||||
benchmark::Counter::kIsRate);
|
||||
|
||||
const int64_t bytes_per_iteration = num_bytes + num_blocks + num_elements * sizeof(float);
|
||||
state.counters["bytes"] =
|
||||
benchmark::Counter(state.iterations() * bytes_per_iteration,
|
||||
benchmark::Counter::kIsRate);
|
||||
}
|
||||
|
||||
constexpr int64_t mega = INT64_C(1048576);
|
||||
BENCHMARK(mf4_f32_convert)->Arg(256 * mega)->UseManualTime()->Unit(benchmark::kMicrosecond);
|
||||
|
||||
BENCHMARK_MAIN();
|
||||
53
gpt_oss/metal/benchmark/u32-random.cc
Normal file
53
gpt_oss/metal/benchmark/u32-random.cc
Normal file
@@ -0,0 +1,53 @@
|
||||
#include <gpt-oss.h>
|
||||
#include <internal/metal.hpp>
|
||||
#include <internal/metal-kernels.h>
|
||||
|
||||
#include <benchmark/benchmark.h>
|
||||
|
||||
using gptoss::Check;
|
||||
using namespace gptoss::metal;
|
||||
|
||||
static void u32_fill_random(benchmark::State& state) {
|
||||
const size_t numel = state.range(0);
|
||||
|
||||
Device device;
|
||||
CommandQueue command_queue{device};
|
||||
Library library{device};
|
||||
Function u32_fill_random_fn{library, "gptoss_u32_fill_random"};
|
||||
Buffer buffer{device, numel * sizeof(float)};
|
||||
|
||||
constexpr uint64_t seed = UINT64_C(1019827666124465388);
|
||||
constexpr uint64_t offset = UINT64_C(12345678901234567890);
|
||||
for (auto _ : state) {
|
||||
CommandBuffer command_buffer{command_queue};
|
||||
|
||||
Check(gptoss_metal_command_buffer_encode_launch_u32_fill_random(
|
||||
command_buffer.handle(),
|
||||
u32_fill_random_fn.handle(),
|
||||
/*threadgroup_size=*/0,
|
||||
/*max_threadgroups=*/120,
|
||||
/*output_buffer=*/buffer.handle(),
|
||||
/*output_offset=*/0,
|
||||
numel, seed, offset),
|
||||
"gptoss_metal_command_buffer_encode_launch_u32_fill_random");
|
||||
|
||||
command_buffer.commit();
|
||||
const double elapsed_seconds = command_buffer.wait_completion();
|
||||
state.SetIterationTime(elapsed_seconds);
|
||||
}
|
||||
|
||||
const int64_t elements_per_iteration = numel;
|
||||
state.counters["elements"] =
|
||||
benchmark::Counter(state.iterations() * elements_per_iteration,
|
||||
benchmark::Counter::kIsRate);
|
||||
|
||||
const int64_t bytes_per_iteration = numel * sizeof(float);
|
||||
state.counters["bytes"] =
|
||||
benchmark::Counter(state.iterations() * bytes_per_iteration,
|
||||
benchmark::Counter::kIsRate);
|
||||
}
|
||||
|
||||
constexpr int64_t giga = INT64_C(1073741824);
|
||||
BENCHMARK(u32_fill_random)->Arg(2 * giga)->UseManualTime()->Unit(benchmark::kMicrosecond);
|
||||
|
||||
BENCHMARK_MAIN();
|
||||
104
gpt_oss/metal/examples/chat.py
Executable file
104
gpt_oss/metal/examples/chat.py
Executable file
@@ -0,0 +1,104 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from datetime import date
|
||||
from gpt_oss.metal import Context, Model
|
||||
|
||||
|
||||
DEFAULT_PROMPT = f"""You are ChatGPT, a large language model trained by OpenAI.
|
||||
Knowledge cutoff: 2024-06
|
||||
Current date: {date.today().isoformat()}
|
||||
|
||||
reasoning effort high
|
||||
|
||||
# Valid channels: analysis, final. Channel must be included for every message."""
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Chat with gpt-oss", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("model", metavar="PATH", type=str, help="Path to gpt-oss model in Metal inference format")
|
||||
parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, help="System prompt")
|
||||
parser.add_argument(
|
||||
"--context-length", type=int, default=0, help="The maximum context length"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature", type=float, default=1.0, help="Sampling temperature"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed", type=int, default=0, help="Sampling seed"
|
||||
)
|
||||
|
||||
|
||||
GREY = "\33[90m"
|
||||
BOLD = "\33[1m"
|
||||
RESET = "\33[0m"
|
||||
|
||||
|
||||
def main(args):
|
||||
options = parser.parse_args(args)
|
||||
model = Model(options.model)
|
||||
tokenizer = model.tokenizer
|
||||
start_token = tokenizer.encode_special_token("<|start|>")
|
||||
message_token = tokenizer.encode_special_token("<|message|>")
|
||||
end_token = tokenizer.encode_special_token("<|end|>")
|
||||
return_token = tokenizer.encode_special_token("<|return|>")
|
||||
channel_token = tokenizer.encode_special_token("<|channel|>")
|
||||
|
||||
context = Context(model, context_length=options.context_length)
|
||||
context.append(start_token)
|
||||
context.append("system")
|
||||
context.append(message_token)
|
||||
context.append(options.prompt)
|
||||
context.append(end_token)
|
||||
|
||||
while True:
|
||||
context.append(start_token)
|
||||
context.append("user")
|
||||
context.append(message_token)
|
||||
message = input(f"{BOLD}User:{RESET} ").rstrip()
|
||||
context.append(message)
|
||||
context.append(end_token)
|
||||
print(f"{BOLD}Assistant:{RESET} {GREY}", end="", flush=True)
|
||||
context.append(start_token)
|
||||
context.append("assistant")
|
||||
context.append(channel_token)
|
||||
|
||||
inside_start_block = True
|
||||
inside_channel_block = True
|
||||
role = "assistant"
|
||||
channel = ""
|
||||
while True:
|
||||
token = context.sample(
|
||||
temperature=options.temperature,
|
||||
seed=options.seed,
|
||||
)
|
||||
context.append(token)
|
||||
if token == return_token:
|
||||
print(flush=True)
|
||||
break
|
||||
elif token == start_token:
|
||||
inside_start_block = True
|
||||
role = ""
|
||||
channel = ""
|
||||
elif token == message_token:
|
||||
inside_start_block = False
|
||||
inside_channel_block = False
|
||||
if channel == "analysis":
|
||||
print(f"{GREY}", end="", flush=True)
|
||||
elif token == end_token:
|
||||
print(f"{RESET}", flush=True)
|
||||
elif token == channel_token:
|
||||
inside_channel_block = True
|
||||
elif token < tokenizer.num_text_tokens:
|
||||
if inside_channel_block:
|
||||
channel += str(tokenizer.decode(token), encoding="utf-8")
|
||||
elif inside_start_block:
|
||||
role += str(tokenizer.decode(token), encoding="utf-8")
|
||||
else:
|
||||
sys.stdout.buffer.write(tokenizer.decode(token))
|
||||
sys.stdout.buffer.flush()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(sys.argv[1:])
|
||||
35
gpt_oss/metal/examples/generate.py
Normal file
35
gpt_oss/metal/examples/generate.py
Normal file
@@ -0,0 +1,35 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from datetime import date
|
||||
from gpt_oss import Context, Model
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='Chat with gpt-oss', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('model', metavar='PATH', type=str, help='Path to gpt-oss checkpoint')
|
||||
parser.add_argument('-p', '--prompt', type=str, required=True, help='Prompt')
|
||||
parser.add_argument('-l', '--limit', type=int, default=100, help='Number of tokens to generate')
|
||||
parser.add_argument('--context-length', type=int, default=0, help='The maximum context length')
|
||||
|
||||
|
||||
def main(args):
|
||||
options = parser.parse_args(args)
|
||||
model = Model(options.model)
|
||||
|
||||
context = Context(model, context_length=options.context_length)
|
||||
context.append(options.prompt)
|
||||
print(context.tokens)
|
||||
prompt_tokens = context.num_tokens
|
||||
|
||||
tokenizer = model.tokenizer
|
||||
|
||||
while context.num_tokens - prompt_tokens < options.limit:
|
||||
token = context.sample()
|
||||
context.append(token)
|
||||
print(str(tokenizer.decode(token), encoding="utf-8"), end='', flush=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main(sys.argv[1:])
|
||||
5
gpt_oss/metal/include/gpt-oss.h
Normal file
5
gpt_oss/metal/include/gpt-oss.h
Normal file
@@ -0,0 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <gpt-oss/macros.h>
|
||||
#include <gpt-oss/types.h>
|
||||
#include <gpt-oss/functions.h>
|
||||
397
gpt_oss/metal/include/gpt-oss/functions.h
Normal file
397
gpt_oss/metal/include/gpt-oss/functions.h
Normal file
@@ -0,0 +1,397 @@
|
||||
#pragma once
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <gpt-oss/macros.h>
|
||||
#include <gpt-oss/types.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Creates a Model object from a file in the filesystem.
|
||||
*
|
||||
* @param path Path to the file containing the model in GPT-OSS format.
|
||||
* @param model_out Pointer to the Model object that will be created. Must be released with gptoss_release_model.
|
||||
*
|
||||
* On success, returns gptoss_status_success and saves a pointer to the created Model in the model_out argument.
|
||||
* On failure, returns an error code and stores null pointer in the model_out argument.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file(
|
||||
const char* path,
|
||||
gptoss_model_t* model_out);
|
||||
|
||||
/*
|
||||
* Query the Tokenizer object associated with the Model.
|
||||
*
|
||||
* @param model Pointer to the Model object created by gptoss_model_create_from_file.
|
||||
* @param tokenizer_out Pointer to the variable where the Tokenizer reference will be stored.
|
||||
*
|
||||
* On success, returns gptoss_status_success and stores reference to the Tokenizer object in the tokenizer_out argument.
|
||||
* On failure, returns an error code and stores NULL in the tokenizer_out argument.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_model_get_tokenizer(
|
||||
gptoss_model_t model,
|
||||
gptoss_tokenizer_t* tokenizer_out);
|
||||
|
||||
/*
|
||||
* Query the maximum context length supported by the Model.
|
||||
*
|
||||
* @param model Pointer to the Model object created by gptoss_model_create_from_file.
|
||||
* @param max_context_length_out Pointer to the variable where the maximum context length will be stored.
|
||||
*
|
||||
* On success, returns gptoss_status_success and stores maximum context length in the max_context_length_out argument.
|
||||
* On failure, returns an error code and leaves the value specified by max_context_length_out unchanged.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_model_get_max_context_length(
|
||||
gptoss_model_t model,
|
||||
size_t* max_context_length_out);
|
||||
|
||||
/*
|
||||
* Increments a Model object's reference count.
|
||||
*
|
||||
* @param model Pointer to the Model object created by gptoss_model_create_from_file.
|
||||
*
|
||||
* On success, returns gptoss_status_success, otherwise returns an error code.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_model_retain(
|
||||
gptoss_model_t model);
|
||||
|
||||
/*
|
||||
* Decrements a Model object's reference count and possibly release associated resources.
|
||||
*
|
||||
* @param model Pointer to the Model object created by gptoss_model_create_from_file.
|
||||
*
|
||||
* On success, returns gptoss_status_success, otherwise returns an error code.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_model_release(
|
||||
gptoss_model_t model);
|
||||
|
||||
/*
|
||||
* Query the token ID for a special token in the Tokenizer vocabulary.
|
||||
*
|
||||
* @param tokenizer Pointer to the Tokenizer object created by gptoss_model_get_tokenizer.
|
||||
* @param token_type Type of the special token to query an ID for.
|
||||
* @param token_id_out Pointer to the variable where the token ID will be stored.
|
||||
*
|
||||
* On success, returns gptoss_status_success and stores the token ID in the token_id_out argument.
|
||||
* On failure, returns an error code and leaves the value specified by token_id_out unchanged.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_special_token_id(
|
||||
gptoss_tokenizer_t tokenizer,
|
||||
enum gptoss_special_token token_type,
|
||||
uint32_t* token_id_out);
|
||||
|
||||
/*
|
||||
* Query the number of text tokens in the Tokenizer vocabulary.
|
||||
*
|
||||
* @param tokenizer Pointer to the Tokenizer object created by gptoss_model_get_tokenizer.
|
||||
* @param num_text_tokens_out Pointer to the variable where the number of text tokens will be stored.
|
||||
*
|
||||
* On success, returns gptoss_status_success and stores the number of text tokens in the num_text_tokens_out argument.
|
||||
* On failure, returns an error code and leaves the value specified by num_text_tokens_out unchanged.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_text_tokens(
|
||||
gptoss_tokenizer_t tokenizer,
|
||||
uint32_t* num_text_tokens_out);
|
||||
|
||||
/*
|
||||
* Query the number of special tokens in the Tokenizer vocabulary.
|
||||
*
|
||||
* @param tokenizer Pointer to the Tokenizer object created by gptoss_model_get_tokenizer.
|
||||
* @param num_special_tokens_out Pointer to the variable where the number of special tokens will be stored.
|
||||
*
|
||||
* On success, returns gptoss_status_success and stores the number of text tokens in the num_special_tokens_out argument.
|
||||
* On failure, returns an error code and leaves the value specified by num_special_tokens_out unchanged.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_special_tokens(
|
||||
gptoss_tokenizer_t tokenizer,
|
||||
uint32_t* num_special_tokens_out);
|
||||
|
||||
/*
|
||||
* Query the total number of tokens in the Tokenizer vocabulary.
|
||||
*
|
||||
* @param tokenizer Pointer to the Tokenizer object created by gptoss_model_get_tokenizer.
|
||||
* @param num_tokens_out Pointer to the variable where the total number of tokens will be stored.
|
||||
*
|
||||
* On success, returns gptoss_status_success and stores the total number of tokens in the num_special_tokens_out argument.
|
||||
* On failure, returns an error code and leaves the value specified by num_special_tokens_out unchanged.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_tokens(
|
||||
gptoss_tokenizer_t tokenizer,
|
||||
uint32_t* num_tokens_out);
|
||||
|
||||
/*
|
||||
* Convert a text token ID to byte representation.
|
||||
*
|
||||
* @param tokenizer Pointer to the Tokenizer object returned by gptoss_model_get_tokenizer. The lifetime of the returned
|
||||
* byte representation would match the lifetime of this Tokenizer object.
|
||||
* @param token_ptr_out Pointer to the variable where the pointer to the byte representation of the token will be
|
||||
* stored.
|
||||
* @param token_size_out Pointer to the variable where the size of the byte representation of the token will be stored.
|
||||
*
|
||||
* On success, returns gptoss_status_success and stores pointer and size of the byte representation of the token in the
|
||||
* token_ptr_out and token_size_out arguments.
|
||||
* On failure, returns an error code and leaves the values specified in token_ptr_out and token_size_out unchanged.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_decode(
|
||||
gptoss_tokenizer_t tokenizer,
|
||||
uint32_t token_id,
|
||||
const void** token_ptr_out,
|
||||
size_t* token_size_out);
|
||||
|
||||
/*
|
||||
* Increments a Tokenizer object's reference count.
|
||||
*
|
||||
* @param tokenizer Pointer to the Tokenizer object returned by gptoss_model_get_tokenizer.
|
||||
*
|
||||
* On success, returns gptoss_status_success, otherwise returns an error code.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_retain(
|
||||
gptoss_tokenizer_t tokenizer);
|
||||
|
||||
/*
|
||||
* Decrements a Tokenizer object's reference count and possibly release associated resources.
|
||||
*
|
||||
* @param tokenizer Pointer to the Tokenizer object returned by gptoss_model_get_tokenizer.
|
||||
*
|
||||
* On success, returns gptoss_status_success, otherwise returns an error code.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_release(
|
||||
gptoss_tokenizer_t tokenizer);
|
||||
|
||||
/*
|
||||
* Creates a Context object for use with the particular Model object.
|
||||
*
|
||||
* @param model Model object to create a context for.
|
||||
* @param context_length Maximum number of tokens in the context.
|
||||
* Specify 0 to use the maximum context length supported by the model.
|
||||
* @param batch_size Maximum number of tokens that can be processed in a single batch.
|
||||
* Larger values may improve performance, but require more memory.
|
||||
* @param context_out Pointer to the Context object that will be created.
|
||||
* Must be released with gptoss_release_context.
|
||||
*
|
||||
* On success, returns gptoss_status_success and saves a pointer to the created Context in the context_out argument.
|
||||
* On failure, returns an error code and stores null pointer in the context_out argument.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_context_create(
|
||||
gptoss_model_t model,
|
||||
size_t context_length,
|
||||
gptoss_context_t* context_out);
|
||||
|
||||
/*
|
||||
* Query the current number of tokens cached in the Context.
|
||||
*
|
||||
* @param context Pointer to the Context object created by gptoss_context_create.
|
||||
* @param num_tokens_out Pointer to the variable where the current number of cached tokens will be stored.
|
||||
*
|
||||
* On success, returns gptoss_status_success and stores current number of cached tokens in the num_tokens_out argument.
|
||||
* On failure, returns an error code and leaves the value specified by num_tokens_out unchanged.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_context_get_num_tokens(
|
||||
gptoss_context_t context,
|
||||
size_t* num_tokens_out);
|
||||
|
||||
/*
|
||||
* Query the maximum number of tokens cached in the Context.
|
||||
*
|
||||
* @param context Pointer to the Context object created by gptoss_context_create.
|
||||
* @param max_tokens_out Pointer to the variable where the maximum number of cached tokens will be stored.
|
||||
*
|
||||
* On success, returns gptoss_status_success and stores maximum number of cached tokens in the max_tokens_out argument.
|
||||
* On failure, returns an error code and leaves the value specified by max_tokens_out unchanged.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_context_get_max_tokens(
|
||||
gptoss_context_t context,
|
||||
size_t* max_tokens_out);
|
||||
|
||||
/*
|
||||
* Query the list of token IDs cached in the Context.
|
||||
*
|
||||
* @param context Pointer to the Context object created by gptoss_context_create.
|
||||
* @param tokens_out Pointer to the array where up to max_tokens_out of cached tokens will be stored.
|
||||
* @param max_tokens Maximum capacity of the buffer specified by tokens_out.
|
||||
* @param num_tokens_out Pointer to the variable where the actual number of cached tokens will be stored.
|
||||
* This value can exceed max_tokens if the buffer capacity is insufficient.
|
||||
*
|
||||
* On success, returns gptoss_status_success and stores cached token IDs in the tokens_out argument and the number of
|
||||
* cached tokens in the num_tokens_out argument.
|
||||
* On failure, returns an error code and leaves the values specified by tokend_out and num_tokens_out unchanged.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_context_get_tokens(
|
||||
gptoss_context_t context,
|
||||
uint32_t* tokens_out,
|
||||
size_t max_tokens,
|
||||
size_t* num_tokens_out);
|
||||
|
||||
/*
|
||||
* Tokenize and appends a character string to the Context object.
|
||||
*
|
||||
* @param context Context object created by gptoss_context_create.
|
||||
* @param text Pointer to the character string to tokenizer and append.
|
||||
* @param text_length Length of the string, in chars.
|
||||
* @param num_tokens_out Optional pointer to the variable where the number of appended tokens will be stored. Ignored if a null pointer is provided.
|
||||
*
|
||||
* On success, returns gptoss_status_success, otherwise returns an error code.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_context_append_chars(
|
||||
gptoss_context_t context,
|
||||
const char* text,
|
||||
size_t text_length,
|
||||
size_t* num_tokens_out);
|
||||
|
||||
/*
|
||||
* Appends a list of tokens to the context.
|
||||
*
|
||||
* @param context Context object created by gptoss_context_create.
|
||||
* @param num_tokens Number of tokens to be appended.
|
||||
* @param tokens Pointer to the array of tokens to be appended.
|
||||
*
|
||||
* On success, returns gptoss_status_success, otherwise returns an error code.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_context_append_tokens(
|
||||
gptoss_context_t context,
|
||||
size_t num_tokens,
|
||||
const uint32_t* tokens);
|
||||
|
||||
/*
|
||||
* Resets the context, clearing its state.
|
||||
*
|
||||
* @param context Context object created by gptoss_context_create.
|
||||
*
|
||||
* On success, returns gptoss_status_success, otherwise returns an error code.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_context_reset(
|
||||
gptoss_context_t context);
|
||||
|
||||
/*
|
||||
* Pre-process the tokens in the Context and generate probability distrubution over the next token.
|
||||
*
|
||||
* @param context Context object created by gptoss_context_create.
|
||||
*
|
||||
* On success, returns gptoss_status_success, otherwise returns an error code.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_context_process(
|
||||
gptoss_context_t context);
|
||||
|
||||
/*
|
||||
* Generate a token probability distribution over the next token conditioned on the Context.
|
||||
*
|
||||
* @param context Context object created by gptoss_context_create.
|
||||
* @param temperature Sampling temperature. Must be non-negative.
|
||||
* @param seed Random number generator seed to use for sampling.
|
||||
* @param token_out Pointer to the variable where the token ID will be stored.
|
||||
*
|
||||
* On success, returns gptoss_status_success, otherwise returns an error code.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_context_sample(
|
||||
gptoss_context_t context,
|
||||
float temperature,
|
||||
uint64_t seed,
|
||||
uint32_t* token_out);
|
||||
|
||||
/*
|
||||
* Increments a Context object's reference count.
|
||||
*
|
||||
* @param context Pointer to the Context object created by gptoss_create_context.
|
||||
*
|
||||
* On success, returns gptoss_status_success, otherwise returns an error code.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_context_retain(
|
||||
gptoss_context_t context);
|
||||
|
||||
/*
|
||||
* Decrements a Context object's reference count and possibly release associated resources.
|
||||
*
|
||||
* @param context Pointer to the Context object created by gptoss_create_context.
|
||||
*
|
||||
* On success, returns gptoss_status_success, otherwise returns an error code.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_context_release(
|
||||
gptoss_context_t context);
|
||||
|
||||
/*
|
||||
* Creates a Sampler object.
|
||||
*
|
||||
* @param sampler_out Pointer to the Sampler object that will be created.
|
||||
* Must be released with gptoss_sampler_release.
|
||||
*
|
||||
* On success, returns gptoss_status_success and saves a pointer to the created Sampler in the sampler_out argument.
|
||||
* On failure, returns an error code and stores a null pointer in the sampler_out argument.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_sampler_create(
|
||||
gptoss_sampler_t* sampler_out);
|
||||
|
||||
/*
|
||||
* Sets the sampling temperature for the Sampler.
|
||||
*
|
||||
* @param sampler Sampler object created by gptoss_sampler_create.
|
||||
* @param temperature Temperature value to be set. Must be in the [0.0, 1.0] range.
|
||||
*
|
||||
* On success, returns gptoss_status_success, otherwise returns an error code.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_sampler_set_temperature(
|
||||
gptoss_sampler_t sampler,
|
||||
float temperature);
|
||||
|
||||
/*
|
||||
* Sets the Top-P nucleus sampling parameter for the Sampler.
|
||||
*
|
||||
* @param sampler Sampler object created by gptoss_sampler_create.
|
||||
* @param top_p Top-P value to be set. Must be in the (0.0, 1.0] range.
|
||||
*
|
||||
* On success, returns gptoss_status_success, otherwise returns an error code.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_sampler_set_top_p(
|
||||
gptoss_sampler_t sampler,
|
||||
float top_p);
|
||||
|
||||
/*
|
||||
* Sets the presence penalty for the Sampler.
|
||||
*
|
||||
* @param sampler Sampler object created by gptoss_sampler_create.
|
||||
* @param presence_penalty Presence penalty value to be set. Must be in the [-2.0, 2.0] range.
|
||||
*
|
||||
* On success, returns gptoss_status_success, otherwise returns an error code.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_sampler_set_presence_penalty(
|
||||
gptoss_sampler_t sampler,
|
||||
float presence_penalty);
|
||||
|
||||
/*
|
||||
* Sets the frequency penalty for the Sampler.
|
||||
*
|
||||
* @param sampler Sampler object created by gptoss_sampler_create.
|
||||
* @param frequency_penalty Frequency penalty value to be set. Must be in the [-2.0, 2.0] range.
|
||||
*
|
||||
* On success, returns gptoss_status_success, otherwise returns an error code.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_sampler_set_frequency_penalty(
|
||||
gptoss_sampler_t sampler,
|
||||
float frequency_penalty);
|
||||
|
||||
/*
|
||||
* Increments a Sampler object's reference count.
|
||||
*
|
||||
* @param sampler Pointer to the Sampler object created by gptoss_sampler_create.
|
||||
*
|
||||
* On success, returns gptoss_status_success, otherwise returns an error code.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_sampler_retain(
|
||||
gptoss_sampler_t sampler);
|
||||
|
||||
/*
|
||||
* Decrements a Sampler object's reference count and possibly releases associated resources.
|
||||
*
|
||||
* @param sampler Pointer to the Sampler object created by gptoss_sampler_create.
|
||||
*
|
||||
* On success, returns gptoss_status_success, otherwise returns an error code.
|
||||
*/
|
||||
enum gptoss_status GPTOSS_ABI gptoss_sampler_release(
|
||||
gptoss_sampler_t sampler);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
5
gpt_oss/metal/include/gpt-oss/macros.h
Normal file
5
gpt_oss/metal/include/gpt-oss/macros.h
Normal file
@@ -0,0 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef GPTOSS_ABI
|
||||
#define GPTOSS_ABI
|
||||
#endif // GPTOSS_ABI
|
||||
62
gpt_oss/metal/include/gpt-oss/types.h
Normal file
62
gpt_oss/metal/include/gpt-oss/types.h
Normal file
@@ -0,0 +1,62 @@
|
||||
#pragma once
|
||||
|
||||
/*
|
||||
* Status codes returned by GPT-OSS API functions.
|
||||
*/
|
||||
enum gptoss_status {
|
||||
gptoss_status_success = 0,
|
||||
gptoss_status_invalid_argument = 1,
|
||||
gptoss_status_unsupported_argument = 2,
|
||||
gptoss_status_invalid_state = 3,
|
||||
gptoss_status_io_error = 4,
|
||||
gptoss_status_insufficient_memory = 5,
|
||||
gptoss_status_insufficient_resources = 6,
|
||||
gptoss_status_unsupported_system = 7,
|
||||
gptoss_status_context_overflow = 8,
|
||||
};
|
||||
|
||||
enum gptoss_special_token {
|
||||
gptoss_special_token_invalid = 0,
|
||||
gptoss_special_token_return = 1,
|
||||
gptoss_special_token_start = 2,
|
||||
gptoss_special_token_message = 3,
|
||||
gptoss_special_token_end = 4,
|
||||
gptoss_special_token_refusal = 5,
|
||||
gptoss_special_token_constrain = 6,
|
||||
gptoss_special_token_channel = 7,
|
||||
gptoss_special_token_call = 8,
|
||||
gptoss_special_token_untrusted = 9,
|
||||
gptoss_special_token_end_untrusted = 10,
|
||||
gptoss_special_token_max,
|
||||
};
|
||||
|
||||
/*
|
||||
* Model object is an opaque container comprised of:
|
||||
* - Weights
|
||||
* - Temporary buffers required to run the model
|
||||
* - Any other resources requires to run the model
|
||||
*/
|
||||
typedef struct gptoss_model* gptoss_model_t;
|
||||
|
||||
typedef struct gptoss_tokenizer* gptoss_tokenizer_t;
|
||||
|
||||
/*
|
||||
* Context is an opaque container comprised of:
|
||||
* - Input tokens
|
||||
* - Distribution over the output tokens
|
||||
* - KV cache
|
||||
*
|
||||
* Multiple contexts can be created and used with the same model.
|
||||
*/
|
||||
typedef struct gptoss_context* gptoss_context_t;
|
||||
|
||||
/*
|
||||
* Sampler is an opaque container for sampling parameters:
|
||||
* - Temperature
|
||||
* - Top-p (nucleus sampling)
|
||||
* - Frequency penalty
|
||||
* - Presence penalty
|
||||
*
|
||||
* Multiple samplers can be created and used with the same context.
|
||||
*/
|
||||
typedef struct gptoss_sampler* gptoss_sampler_t;
|
||||
265
gpt_oss/metal/python/context.c
Normal file
265
gpt_oss/metal/python/context.c
Normal file
@@ -0,0 +1,265 @@
|
||||
#include <Python.h>
|
||||
|
||||
#include <gpt-oss.h>
|
||||
|
||||
#include "module.h"
|
||||
|
||||
|
||||
static int PyGPTOSSContext_init(PyGPTOSSContext* self, PyObject* args, PyObject* kwargs) {
|
||||
static char *kwlist[] = {"model", "context_length", NULL};
|
||||
PyObject* model = NULL;
|
||||
Py_ssize_t context_length = 0; // Default to 0 if None
|
||||
|
||||
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|$i", kwlist,
|
||||
&model, &context_length)) {
|
||||
return -1;
|
||||
}
|
||||
if (!PyObject_TypeCheck(model, &PyGPTOSSModel_Type)) {
|
||||
PyErr_SetString(PyExc_TypeError, "model must be an gptoss.Model object");
|
||||
return -1;
|
||||
}
|
||||
if (context_length < 0) {
|
||||
PyErr_SetString(PyExc_ValueError, "context_length must be a positive integer");
|
||||
return -1;
|
||||
}
|
||||
|
||||
enum gptoss_status status = gptoss_context_create(
|
||||
((const PyGPTOSSModel*) model)->handle,
|
||||
(size_t) context_length,
|
||||
&self->handle);
|
||||
if (status != gptoss_status_success) {
|
||||
// TODO: set exception
|
||||
goto error;
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
error:
|
||||
gptoss_context_release(self->handle);
|
||||
self->handle = NULL;
|
||||
return -1;
|
||||
}
|
||||
|
||||
static void PyGPTOSSContext_dealloc(PyGPTOSSContext* self) {
|
||||
(void) gptoss_context_release(self->handle);
|
||||
self->handle = NULL;
|
||||
PyObject_Del((PyObject*) self);
|
||||
}
|
||||
|
||||
static PyObject* PyGPTOSSContext_copy(PyGPTOSSContext *self) {
|
||||
PyGPTOSSContext* copy = (PyGPTOSSContext*) PyObject_New(PyGPTOSSContext, Py_TYPE(self));
|
||||
if (copy == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
(void) gptoss_context_retain(self->handle);
|
||||
copy->handle = self->handle;
|
||||
return (PyObject*) copy;
|
||||
}
|
||||
|
||||
static PyObject* PyGPTOSSContext_append(PyGPTOSSContext* self, PyObject* arg) {
|
||||
if (PyBytes_Check(arg)) {
|
||||
char* string_ptr = NULL;
|
||||
Py_ssize_t string_size = 0;
|
||||
if (PyBytes_AsStringAndSize(arg, &string_ptr, &string_size) < 0) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
const enum gptoss_status status = gptoss_context_append_chars(
|
||||
self->handle, string_ptr, string_size, /*num_tokens_out=*/NULL);
|
||||
if (status != gptoss_status_success) {
|
||||
// TODO: set exception
|
||||
return NULL;
|
||||
}
|
||||
|
||||
Py_RETURN_NONE;
|
||||
} else if (PyUnicode_Check(arg)) {
|
||||
Py_ssize_t string_size = 0;
|
||||
const char* string_ptr = PyUnicode_AsUTF8AndSize(arg, &string_size);
|
||||
if (string_ptr == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
const enum gptoss_status status = gptoss_context_append_chars(
|
||||
self->handle, string_ptr, string_size, /*num_tokens_out=*/NULL);
|
||||
if (status != gptoss_status_success) {
|
||||
// TODO: set exception
|
||||
return NULL;
|
||||
}
|
||||
|
||||
Py_RETURN_NONE;
|
||||
} else if (PyLong_Check(arg)) {
|
||||
const unsigned long token_as_ulong = PyLong_AsUnsignedLong(arg);
|
||||
if (token_as_ulong == (unsigned long) -1 && PyErr_Occurred()) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
const uint32_t token = (uint32_t) token_as_ulong;
|
||||
const enum gptoss_status status = gptoss_context_append_tokens(
|
||||
self->handle, /*num_tokens=*/1, &token);
|
||||
if (status != gptoss_status_success) {
|
||||
// TODO: set exception
|
||||
return NULL;
|
||||
}
|
||||
|
||||
Py_RETURN_NONE;
|
||||
} else {
|
||||
PyErr_SetString(PyExc_TypeError, "expected a bytes or integer argument");
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
static PyObject* PyGPTOSSContext_process(PyGPTOSSContext* self) {
|
||||
const enum gptoss_status status = gptoss_context_process(self->handle);
|
||||
if (status != gptoss_status_success) {
|
||||
// TODO: set exception
|
||||
return NULL;
|
||||
}
|
||||
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject* PyGPTOSSContext_sample(PyGPTOSSContext* self, PyObject* args, PyObject* kwargs) {
|
||||
static char *kwlist[] = {"temperature", "seed", NULL};
|
||||
|
||||
unsigned long long seed = 0;
|
||||
float temperature = 1.0f;
|
||||
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|$fK", kwlist,
|
||||
&temperature, &seed))
|
||||
{
|
||||
return NULL;
|
||||
}
|
||||
|
||||
uint32_t token_out = UINT32_MAX;
|
||||
enum gptoss_status status = gptoss_context_sample(
|
||||
self->handle, temperature, (uint64_t) seed, &token_out);
|
||||
if (status != gptoss_status_success) {
|
||||
// TODO: set exception
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return PyLong_FromUnsignedLong((unsigned long) token_out);
|
||||
}
|
||||
|
||||
static PyObject* PyGPTOSSContext_reset(PyGPTOSSContext* self) {
|
||||
const enum gptoss_status status = gptoss_context_reset(self->handle);
|
||||
if (status != gptoss_status_success) {
|
||||
// TODO: set exception
|
||||
return NULL;
|
||||
}
|
||||
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyMethodDef PyGPTOSSContext_methods[] = {
|
||||
{"__copy__", (PyCFunction) PyGPTOSSContext_copy, METH_NOARGS, "Create a copy of the Context"},
|
||||
{"append", (PyCFunction) PyGPTOSSContext_append, METH_O, "Append bytes to the Context"},
|
||||
{"process", (PyCFunction) PyGPTOSSContext_process, METH_NOARGS, "Process tokens in the Context"},
|
||||
{"sample", (PyCFunction) PyGPTOSSContext_sample, METH_VARARGS | METH_KEYWORDS, "Sample token prediction from the Context"},
|
||||
{"reset", (PyCFunction) PyGPTOSSContext_reset, METH_NOARGS, "Discard the content of the Context"},
|
||||
{NULL},
|
||||
};
|
||||
|
||||
static PyObject* PyGPTOSSContext_get_num_tokens(PyGPTOSSContext* self, void* closure) {
|
||||
size_t num_tokens = 0;
|
||||
const enum gptoss_status status = gptoss_context_get_num_tokens(self->handle, &num_tokens);
|
||||
if (status != gptoss_status_success) {
|
||||
// TODO: set exception
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return PyLong_FromSize_t(num_tokens);
|
||||
}
|
||||
|
||||
static PyObject* PyGPTOSSContext_get_max_tokens(PyGPTOSSContext* self, void* closure) {
|
||||
size_t max_tokens = 0;
|
||||
const enum gptoss_status status = gptoss_context_get_max_tokens(self->handle, &max_tokens);
|
||||
if (status != gptoss_status_success) {
|
||||
// TODO: set exception
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return PyLong_FromSize_t(max_tokens);
|
||||
}
|
||||
|
||||
static PyObject* PyGPTOSSContext_get_tokens(PyGPTOSSContext* self, void* closure) {
|
||||
PyObject* token_list_obj = NULL;
|
||||
PyObject* token_obj = NULL;
|
||||
uint32_t* token_ptr = NULL;
|
||||
|
||||
size_t num_tokens = 0;
|
||||
gptoss_context_get_tokens(self->handle, /*tokens_out=*/NULL, /*max_tokens=*/0, &num_tokens);
|
||||
|
||||
if (num_tokens != 0) {
|
||||
token_ptr = (uint32_t*) PyMem_Malloc(num_tokens * sizeof(uint32_t));
|
||||
if (token_ptr == NULL) {
|
||||
// TODO: set exception
|
||||
goto error;
|
||||
}
|
||||
|
||||
enum gptoss_status status = gptoss_context_get_tokens(self->handle, token_ptr, /*max_tokens=*/num_tokens, &num_tokens);
|
||||
if (status != gptoss_status_success) {
|
||||
// TODO: set exception
|
||||
goto error;
|
||||
}
|
||||
}
|
||||
|
||||
token_list_obj = PyList_New((Py_ssize_t) num_tokens);
|
||||
if (token_list_obj == NULL) {
|
||||
goto error;
|
||||
}
|
||||
|
||||
for (size_t t = 0; t < num_tokens; t++) {
|
||||
token_obj = PyLong_FromUnsignedLong((unsigned long) token_ptr[t]);
|
||||
if (token_obj == NULL) {
|
||||
goto error;
|
||||
}
|
||||
if (PyList_SetItem(token_list_obj, (Py_ssize_t) t, token_obj) < 0) {
|
||||
goto error;
|
||||
}
|
||||
token_obj = NULL; // PyList_SetItem stole the reference
|
||||
}
|
||||
|
||||
PyMem_Free(token_ptr);
|
||||
return token_list_obj;
|
||||
|
||||
error:
|
||||
PyMem_Free(token_ptr);
|
||||
Py_XDECREF(token_obj);
|
||||
Py_XDECREF(token_list_obj);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static PyGetSetDef PyGPTOSSContext_getseters[] = {
|
||||
(PyGetSetDef) {
|
||||
.name = "num_tokens",
|
||||
.get = (getter) PyGPTOSSContext_get_num_tokens,
|
||||
.doc = "Current number of tokens in the context",
|
||||
},
|
||||
(PyGetSetDef) {
|
||||
.name = "max_tokens",
|
||||
.get = (getter) PyGPTOSSContext_get_max_tokens,
|
||||
.doc = "Maximum number of tokens in the context",
|
||||
},
|
||||
(PyGetSetDef) {
|
||||
.name = "tokens",
|
||||
.get = (getter) PyGPTOSSContext_get_tokens,
|
||||
.doc = "List of token IDs in the context",
|
||||
},
|
||||
{NULL} /* Sentinel */
|
||||
};
|
||||
|
||||
PyTypeObject PyGPTOSSContext_Type = {
|
||||
PyVarObject_HEAD_INIT(NULL, 0)
|
||||
.tp_name = "gptoss.Context",
|
||||
.tp_basicsize = sizeof(PyGPTOSSContext),
|
||||
.tp_flags = 0
|
||||
| Py_TPFLAGS_DEFAULT
|
||||
| Py_TPFLAGS_BASETYPE,
|
||||
.tp_doc = "Context object",
|
||||
.tp_methods = PyGPTOSSContext_methods,
|
||||
.tp_getset = PyGPTOSSContext_getseters,
|
||||
.tp_new = PyType_GenericNew,
|
||||
.tp_init = (initproc) PyGPTOSSContext_init,
|
||||
.tp_dealloc = (destructor) PyGPTOSSContext_dealloc,
|
||||
};
|
||||
94
gpt_oss/metal/python/model.c
Normal file
94
gpt_oss/metal/python/model.c
Normal file
@@ -0,0 +1,94 @@
|
||||
#include <Python.h>
|
||||
|
||||
#include <gpt-oss.h>
|
||||
|
||||
#include "module.h"
|
||||
|
||||
|
||||
static int PyGPTOSSModel_init(PyGPTOSSModel* self, PyObject* args, PyObject* kwargs) {
|
||||
enum gptoss_status status;
|
||||
const char* filepath;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "s", &filepath)) {
|
||||
return -1;
|
||||
}
|
||||
status = gptoss_model_create_from_file(filepath, &self->handle);
|
||||
if (status != gptoss_status_success) {
|
||||
// TODO: set exception
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void PyGPTOSSModel_dealloc(PyGPTOSSModel* self) {
|
||||
(void) gptoss_model_release(self->handle);
|
||||
self->handle = NULL;
|
||||
PyObject_Del((PyObject*) self);
|
||||
}
|
||||
|
||||
static PyObject* PyGPTOSSModel_copy(PyGPTOSSModel* self) {
|
||||
PyGPTOSSModel* copy = (PyGPTOSSModel*) PyObject_New(PyGPTOSSModel, Py_TYPE(self));
|
||||
if (copy == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
(void) gptoss_model_retain(self->handle);
|
||||
copy->handle = self->handle;
|
||||
return (PyObject*) copy;
|
||||
}
|
||||
|
||||
static PyMethodDef PyGPTOSSModel_methods[] = {
|
||||
{"__copy__", (PyCFunction) PyGPTOSSModel_copy, METH_NOARGS, "Create a copy of the Model"},
|
||||
{NULL},
|
||||
};
|
||||
|
||||
static PyObject *PyGPTOSSModel_get_max_context_length(PyGPTOSSModel* self, void* closure) {
|
||||
size_t max_context_length = 0;
|
||||
const enum gptoss_status status = gptoss_model_get_max_context_length(self->handle, &max_context_length);
|
||||
if (status != gptoss_status_success) {
|
||||
// TODO: set exception
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return PyLong_FromSize_t(max_context_length);
|
||||
}
|
||||
|
||||
static PyObject *PyGPTOSSModel_get_tokenizer(PyGPTOSSModel* self, void* closure) {
|
||||
PyObject* args = PyTuple_Pack(1, self);
|
||||
if (args == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
PyObject* tokenizer = PyObject_CallObject((PyObject*) &PyGPTOSSTokenizer_Type, args);
|
||||
Py_DECREF(args);
|
||||
return tokenizer;
|
||||
}
|
||||
|
||||
static PyGetSetDef PyGPTOSSModel_getseters[] = {
|
||||
(PyGetSetDef) {
|
||||
.name = "max_context_length",
|
||||
.get = (getter) PyGPTOSSModel_get_max_context_length,
|
||||
.doc = "Maximum context length supported by the model",
|
||||
},
|
||||
(PyGetSetDef) {
|
||||
.name = "tokenizer",
|
||||
.get = (getter) PyGPTOSSModel_get_tokenizer,
|
||||
.doc = "Tokenizer object associated with the model",
|
||||
},
|
||||
{NULL} // Sentinel
|
||||
};
|
||||
|
||||
PyTypeObject PyGPTOSSModel_Type = {
|
||||
PyVarObject_HEAD_INIT(NULL, 0)
|
||||
.tp_name = "gptoss.Model",
|
||||
.tp_basicsize = sizeof(PyGPTOSSModel),
|
||||
.tp_flags = 0
|
||||
| Py_TPFLAGS_DEFAULT
|
||||
| Py_TPFLAGS_BASETYPE,
|
||||
.tp_doc = "Model object",
|
||||
.tp_methods = PyGPTOSSModel_methods,
|
||||
.tp_getset = PyGPTOSSModel_getseters,
|
||||
.tp_new = PyType_GenericNew,
|
||||
.tp_init = (initproc) PyGPTOSSModel_init,
|
||||
.tp_dealloc = (destructor) PyGPTOSSModel_dealloc,
|
||||
};
|
||||
67
gpt_oss/metal/python/module.c
Normal file
67
gpt_oss/metal/python/module.c
Normal file
@@ -0,0 +1,67 @@
|
||||
#include <Python.h>
|
||||
|
||||
#include "module.h"
|
||||
|
||||
|
||||
static PyMethodDef module_methods[] = {
|
||||
{NULL, NULL, 0, NULL}
|
||||
};
|
||||
|
||||
static PyModuleDef metal_module = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
"_metal",
|
||||
"Local GPT-OSS inference",
|
||||
-1,
|
||||
module_methods
|
||||
};
|
||||
|
||||
PyMODINIT_FUNC PyInit__metal(void) {
|
||||
PyObject* module = NULL;
|
||||
PyObject* model_type = NULL;
|
||||
PyObject* tokenizer_type = NULL;
|
||||
PyObject* context_type = NULL;
|
||||
|
||||
if (PyType_Ready(&PyGPTOSSModel_Type) < 0) {
|
||||
goto error;
|
||||
}
|
||||
model_type = (PyObject*) &PyGPTOSSModel_Type;
|
||||
Py_INCREF(model_type);
|
||||
|
||||
if (PyType_Ready(&PyGPTOSSTokenizer_Type) < 0) {
|
||||
goto error;
|
||||
}
|
||||
tokenizer_type = (PyObject*) &PyGPTOSSTokenizer_Type;
|
||||
Py_INCREF(tokenizer_type);
|
||||
|
||||
if (PyType_Ready(&PyGPTOSSContext_Type) < 0) {
|
||||
goto error;
|
||||
}
|
||||
context_type = (PyObject*) &PyGPTOSSContext_Type;
|
||||
Py_INCREF(context_type);
|
||||
|
||||
module = PyModule_Create(&metal_module);
|
||||
if (module == NULL) {
|
||||
goto error;
|
||||
}
|
||||
|
||||
if (PyModule_AddObject(module, "Model", model_type) < 0) {
|
||||
goto error;
|
||||
}
|
||||
|
||||
if (PyModule_AddObject(module, "Tokenizer", tokenizer_type) < 0) {
|
||||
goto error;
|
||||
}
|
||||
|
||||
if (PyModule_AddObject(module, "Context", context_type) < 0) {
|
||||
goto error;
|
||||
}
|
||||
|
||||
return module;
|
||||
|
||||
error:
|
||||
Py_XDECREF(context_type);
|
||||
Py_XDECREF(tokenizer_type);
|
||||
Py_XDECREF(model_type);
|
||||
Py_XDECREF(module);
|
||||
return NULL;
|
||||
}
|
||||
22
gpt_oss/metal/python/module.h
Normal file
22
gpt_oss/metal/python/module.h
Normal file
@@ -0,0 +1,22 @@
|
||||
#include <Python.h>
|
||||
|
||||
#include <gpt-oss.h>
|
||||
|
||||
typedef struct {
|
||||
PyObject_HEAD
|
||||
gptoss_model_t handle;
|
||||
} PyGPTOSSModel;
|
||||
|
||||
typedef struct {
|
||||
PyObject_HEAD
|
||||
gptoss_tokenizer_t handle;
|
||||
} PyGPTOSSTokenizer;
|
||||
|
||||
typedef struct {
|
||||
PyObject_HEAD
|
||||
gptoss_context_t handle;
|
||||
} PyGPTOSSContext;
|
||||
|
||||
extern PyTypeObject PyGPTOSSModel_Type;
|
||||
extern PyTypeObject PyGPTOSSTokenizer_Type;
|
||||
extern PyTypeObject PyGPTOSSContext_Type;
|
||||
185
gpt_oss/metal/python/tokenizer.c
Normal file
185
gpt_oss/metal/python/tokenizer.c
Normal file
@@ -0,0 +1,185 @@
|
||||
#include <Python.h>
|
||||
|
||||
#include <gpt-oss.h>
|
||||
|
||||
#include "module.h"
|
||||
|
||||
static PyObject* PyGPTOSSTokenizer_new(PyTypeObject* subtype, PyObject* args, PyObject* kwargs) {
|
||||
static char *kwlist[] = {"model", NULL};
|
||||
PyObject* model = NULL;
|
||||
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyGPTOSSModel_Type, &model)) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
PyGPTOSSTokenizer* self = (PyGPTOSSTokenizer*) subtype->tp_alloc(subtype, 0);
|
||||
if (self == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
const enum gptoss_status status = gptoss_model_get_tokenizer(
|
||||
((const PyGPTOSSModel*) model)->handle,
|
||||
&self->handle);
|
||||
if (status != gptoss_status_success) {
|
||||
// TODO: set exception
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return (PyObject*) self;
|
||||
}
|
||||
|
||||
static void PyGPTOSSTokenizer_dealloc(PyGPTOSSTokenizer* self) {
|
||||
(void) gptoss_tokenizer_release(self->handle);
|
||||
self->handle = NULL;
|
||||
PyObject_Del((PyObject*) self);
|
||||
}
|
||||
|
||||
static PyObject* PyGPTOSSTokenizer_copy(PyGPTOSSTokenizer* self) {
|
||||
PyGPTOSSTokenizer* copy = (PyGPTOSSTokenizer*) PyObject_New(PyGPTOSSTokenizer, Py_TYPE(self));
|
||||
if (copy == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
(void) gptoss_tokenizer_retain(self->handle);
|
||||
copy->handle = self->handle;
|
||||
return (PyObject*) copy;
|
||||
}
|
||||
|
||||
static PyObject* PyGPTOSSTokenizer_encode_special_token(PyGPTOSSTokenizer* self, PyObject* arg) {
|
||||
if (PyUnicode_Check(arg)) {
|
||||
const char* string_ptr = PyUnicode_AsUTF8(arg);
|
||||
if (string_ptr == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
enum gptoss_special_token token_type = gptoss_special_token_invalid;
|
||||
if (strcmp(string_ptr, "<|return|>") == 0) {
|
||||
token_type = gptoss_special_token_return;
|
||||
} else if (strcmp(string_ptr, "<|start|>") == 0) {
|
||||
token_type = gptoss_special_token_start;
|
||||
} else if (strcmp(string_ptr, "<|message|>") == 0) {
|
||||
token_type = gptoss_special_token_message;
|
||||
} else if (strcmp(string_ptr, "<|end|>") == 0) {
|
||||
token_type = gptoss_special_token_end;
|
||||
} else if (strcmp(string_ptr, "<|refusal|>") == 0) {
|
||||
token_type = gptoss_special_token_refusal;
|
||||
} else if (strcmp(string_ptr, "<|constrain|>") == 0) {
|
||||
token_type = gptoss_special_token_constrain;
|
||||
} else if (strcmp(string_ptr, "<|channel|>") == 0) {
|
||||
token_type = gptoss_special_token_channel;
|
||||
} else if (strcmp(string_ptr, "<|call|>") == 0) {
|
||||
token_type = gptoss_special_token_call;
|
||||
} else if (strcmp(string_ptr, "<|untrusted|>") == 0) {
|
||||
token_type = gptoss_special_token_untrusted;
|
||||
} else if (strcmp(string_ptr, "<|end_untrusted|>") == 0) {
|
||||
token_type = gptoss_special_token_end_untrusted;
|
||||
} else {
|
||||
PyErr_Format(PyExc_ValueError, "unrecognized special token: %s", string_ptr);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
uint32_t token_id = UINT32_MAX;
|
||||
const enum gptoss_status status = gptoss_tokenizer_get_special_token_id(
|
||||
self->handle, token_type, &token_id);
|
||||
if (status != gptoss_status_success || token_id == UINT32_MAX) {
|
||||
PyErr_Format(PyExc_ValueError, "tokenizer does not support the %s token", string_ptr);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return PyLong_FromUnsignedLong((unsigned long) token_id);
|
||||
} else {
|
||||
PyErr_SetString(PyExc_TypeError, "string argument expected");
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
static PyObject* PyGPTOSSTokenizer_decode(PyGPTOSSTokenizer* self, PyObject* args, PyObject* kwargs) {
|
||||
static char *kwlist[] = {"token", NULL};
|
||||
unsigned int token = 0; // Default to 0 if None
|
||||
|
||||
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "I", kwlist, &token)) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
const void* token_ptr = NULL;
|
||||
size_t token_size = 0;
|
||||
const enum gptoss_status status = gptoss_tokenizer_decode(self->handle, (uint32_t) token, &token_ptr, &token_size);
|
||||
if (status != gptoss_status_success) {
|
||||
// TODO: set exception
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return PyBytes_FromStringAndSize((const char*) token_ptr, (Py_ssize_t) token_size);
|
||||
}
|
||||
|
||||
static PyMethodDef PyGPTOSSTokenizer_methods[] = {
|
||||
{"__copy__", (PyCFunction) PyGPTOSSTokenizer_copy, METH_NOARGS, "Create a copy of the Tokenizer"},
|
||||
{"encode_special_token", (PyCFunction) PyGPTOSSTokenizer_encode_special_token, METH_O, "Query ID of a special token"},
|
||||
{"decode", (PyCFunction) PyGPTOSSTokenizer_decode, METH_VARARGS | METH_KEYWORDS, "Convert text token ID to bytes"},
|
||||
{NULL},
|
||||
};
|
||||
|
||||
static PyObject* PyGPTOSSTokenizer_get_num_text_tokens(PyGPTOSSTokenizer* self, void* closure) {
|
||||
uint32_t num_text_tokens = 0;
|
||||
const enum gptoss_status status = gptoss_tokenizer_get_num_text_tokens(self->handle, &num_text_tokens);
|
||||
if (status != gptoss_status_success) {
|
||||
// TODO: set exception
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return PyLong_FromUnsignedLong((unsigned long) num_text_tokens);
|
||||
}
|
||||
|
||||
static PyObject* PyGPTOSSTokenizer_get_num_special_tokens(PyGPTOSSTokenizer* self, void* closure) {
|
||||
uint32_t num_special_tokens = 0;
|
||||
const enum gptoss_status status = gptoss_tokenizer_get_num_special_tokens(self->handle, &num_special_tokens);
|
||||
if (status != gptoss_status_success) {
|
||||
// TODO: set exception
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return PyLong_FromUnsignedLong((unsigned long) num_special_tokens);
|
||||
}
|
||||
|
||||
static PyObject* PyGPTOSSTokenizer_get_num_tokens(PyGPTOSSTokenizer* self, void* closure) {
|
||||
uint32_t num_tokens = 0;
|
||||
const enum gptoss_status status = gptoss_tokenizer_get_num_tokens(self->handle, &num_tokens);
|
||||
if (status != gptoss_status_success) {
|
||||
// TODO: set exception
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return PyLong_FromUnsignedLong((unsigned long) num_tokens);
|
||||
}
|
||||
|
||||
static PyGetSetDef PyGPTOSSTokenizer_getseters[] = {
|
||||
(PyGetSetDef) {
|
||||
.name = "num_tokens",
|
||||
.get = (getter) PyGPTOSSTokenizer_get_num_tokens,
|
||||
.doc = "Total number of tokens in the tokenizer dictionary",
|
||||
},
|
||||
(PyGetSetDef) {
|
||||
.name = "num_text_tokens",
|
||||
.get = (getter) PyGPTOSSTokenizer_get_num_text_tokens,
|
||||
.doc = "Number of text tokens in the tokenizer dictionary",
|
||||
},
|
||||
(PyGetSetDef) {
|
||||
.name = "num_special_tokens",
|
||||
.get = (getter) PyGPTOSSTokenizer_get_num_special_tokens,
|
||||
.doc = "Number of special tokens in the tokenizer dictionary",
|
||||
},
|
||||
{NULL} /* Sentinel */
|
||||
};
|
||||
|
||||
PyTypeObject PyGPTOSSTokenizer_Type = {
|
||||
PyVarObject_HEAD_INIT(NULL, 0)
|
||||
.tp_name = "gptoss.Tokenizer",
|
||||
.tp_basicsize = sizeof(PyGPTOSSTokenizer),
|
||||
.tp_flags = 0
|
||||
| Py_TPFLAGS_DEFAULT
|
||||
| Py_TPFLAGS_BASETYPE,
|
||||
.tp_doc = "Tokenizer object",
|
||||
.tp_methods = PyGPTOSSTokenizer_methods,
|
||||
.tp_getset = PyGPTOSSTokenizer_getseters,
|
||||
.tp_new = PyGPTOSSTokenizer_new,
|
||||
.tp_dealloc = (destructor) PyGPTOSSTokenizer_dealloc,
|
||||
};
|
||||
343
gpt_oss/metal/scripts/create-local-model.py
Normal file
343
gpt_oss/metal/scripts/create-local-model.py
Normal file
@@ -0,0 +1,343 @@
|
||||
import argparse
|
||||
import os
|
||||
import math
|
||||
import sys
|
||||
import json
|
||||
import itertools
|
||||
import struct
|
||||
from uuid import UUID
|
||||
|
||||
import tiktoken
|
||||
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
from tqdm import tqdm
|
||||
from openai_harmony import load_harmony_encoding, HarmonyEncodingName
|
||||
|
||||
parser = argparse.ArgumentParser(prog='check-mxfp4-weights.py', description='Validated MXFP4 weights')
|
||||
parser.add_argument('-s', '--src', metavar='DIR', type=str, required=True, help='Path to the input checkpoint directory')
|
||||
parser.add_argument('-d', '--dst', metavar='FILE', type=str, required=True, help='Path to the output model file')
|
||||
|
||||
|
||||
o200k_base = tiktoken.get_encoding("o200k_base")
|
||||
harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
|
||||
o200k_gptoss = tiktoken.Encoding(
|
||||
name="o200k_gptoss",
|
||||
pat_str=o200k_base._pat_str,
|
||||
mergeable_ranks=o200k_base._mergeable_ranks,
|
||||
special_tokens={
|
||||
"<|reversed199998|>": 199998, # unused
|
||||
"<|endoftext|>": 199999,
|
||||
"<|untrusted|>": 200000,
|
||||
"<|endofuntrusted|>": 200001,
|
||||
"<|return|>": 200002,
|
||||
"<|constrain|>": 200003,
|
||||
"<|reversed200004|>": 200004, # unused
|
||||
"<|channel|>": 200005,
|
||||
"<|start|>": 200006,
|
||||
"<|end|>": 200007,
|
||||
"<|message|>": 200008,
|
||||
"<|reversed200008|>": 200008, # unused
|
||||
"<|reversed200009|>": 200009, # unused
|
||||
"<|reversed200010|>": 200010, # unused
|
||||
"<|reversed200011|>": 200011, # unused
|
||||
"<|call|>": 200012,
|
||||
"<|refusal|>": 200013,
|
||||
}
|
||||
)
|
||||
|
||||
FILE_MAGIC = struct.pack('ccccccccccccI', b'G', b'P', b'T', b'-', b'O', b'S', b'S', b' ', b'v', b'1', b'.', b'0', 0)
|
||||
SPECIAL_TOKEN_UUID = {
|
||||
'<|start|>': UUID('55a77c2f-8a01-4c54-8ac2-313bfc7e208d').bytes,
|
||||
'<|message|>': UUID('16e40431-f47f-4b22-b59b-8b278fc30a54').bytes,
|
||||
'<|end|>': UUID('fcac2f6d-4705-4f6b-b228-642accac7238').bytes,
|
||||
'<|return|>': UUID('f799ff69-1992-43c4-a3d8-d831f475dc75').bytes,
|
||||
'<|refusal|>': UUID('e15ba702-28c4-4292-ab8f-ffa434709128').bytes,
|
||||
'<|constrain|>': UUID('c0bb14c7-6022-49da-ad08-792d67e8b470').bytes,
|
||||
'<|channel|>': UUID('fd3dda11-c8ab-4033-876e-d93deb172c93').bytes,
|
||||
'<|call|>': UUID('1220f796-e388-4de5-b487-fe2eb5fe03c0').bytes,
|
||||
'<|untrusted|>': UUID('07d7da55-b346-4cff-8b37-7cefacf8a3e8').bytes,
|
||||
'<|end_untrusted|>': UUID('f265bd9c-c717-469e-a447-920687d65d90').bytes,
|
||||
}
|
||||
|
||||
INCLUDE_SPECIAL_TOKENS = [
|
||||
"<|start|>",
|
||||
"<|message|>",
|
||||
"<|end|>",
|
||||
"<|return|>",
|
||||
"<|refusal|>",
|
||||
"<|constrain|>",
|
||||
"<|channel|>",
|
||||
"<|call|>",
|
||||
"<|untrusted|>",
|
||||
"<|end_untrusted|>",
|
||||
]
|
||||
|
||||
GPTOSS_MODEL_UUID = UUID('df52dc86-1789-4ed0-a295-66f10508145b').bytes
|
||||
APPLE_GPU_LAYOUT_UUID = UUID('229177a8-5775-4268-bfd8-d588b351c56d').bytes
|
||||
TIKTOKEN_TOKENIZER_UUID = UUID('7401aded-2a95-40cb-b782-9ccebaafe72b').bytes
|
||||
|
||||
UE8_OFFSET = 14 # bias to MXFP4 block scales
|
||||
|
||||
def write_file_header(f):
|
||||
f.write(FILE_MAGIC)
|
||||
|
||||
def write_tokenizer_header(f,
|
||||
num_special_tokens: int,
|
||||
num_text_tokens: int,
|
||||
regex_size: int,
|
||||
tokens_size: int):
|
||||
f.write(TIKTOKEN_TOKENIZER_UUID)
|
||||
f.write(struct.pack('<I', num_special_tokens))
|
||||
f.write(struct.pack('<I', num_text_tokens))
|
||||
f.write(struct.pack('<I', regex_size))
|
||||
f.write(struct.pack('<I', tokens_size))
|
||||
|
||||
def write_model_header(f,
|
||||
context_length : int,
|
||||
num_blocks : int,
|
||||
num_experts : int,
|
||||
num_active_experts : int,
|
||||
embedding_dim : int,
|
||||
mlp_dim : int,
|
||||
swiglu_limit : float,
|
||||
head_dim: int,
|
||||
num_heads : int,
|
||||
num_kv_heads : int,
|
||||
attention_window : int,
|
||||
rope_theta : float,
|
||||
interpolation_scale : float,
|
||||
yarn_offset : float,
|
||||
yarn_scale : float,
|
||||
yarn_multiplier : float,
|
||||
rmsnorm_epsilon : float):
|
||||
f.write(GPTOSS_MODEL_UUID)
|
||||
f.write(struct.pack('<I', context_length))
|
||||
f.write(struct.pack('<I', num_blocks))
|
||||
f.write(struct.pack('<I', num_experts))
|
||||
f.write(struct.pack('<I', num_active_experts))
|
||||
f.write(struct.pack('<I', embedding_dim))
|
||||
f.write(struct.pack('<I', mlp_dim))
|
||||
f.write(struct.pack('<f', swiglu_limit))
|
||||
f.write(struct.pack('<I', head_dim))
|
||||
f.write(struct.pack('<I', num_heads))
|
||||
f.write(struct.pack('<I', num_kv_heads))
|
||||
f.write(struct.pack('<I', attention_window))
|
||||
f.write(struct.pack('<f', rope_theta))
|
||||
f.write(struct.pack('<f', interpolation_scale))
|
||||
f.write(struct.pack('<f', yarn_offset))
|
||||
f.write(struct.pack('<f', yarn_scale))
|
||||
f.write(struct.pack('<f', yarn_multiplier))
|
||||
f.write(struct.pack('<f', rmsnorm_epsilon))
|
||||
f.write(APPLE_GPU_LAYOUT_UUID)
|
||||
|
||||
|
||||
def write_padding(out_file, alignment_multiple=16384):
|
||||
offset = out_file.tell()
|
||||
alignment_size = -offset % alignment_multiple
|
||||
if alignment_size != 0:
|
||||
alignment = bytes(alignment_size)
|
||||
out_file.write(alignment)
|
||||
|
||||
|
||||
def write_embedding_weight(out_file, weight):
|
||||
write_padding(out_file, alignment_multiple=16)
|
||||
|
||||
assert weight.dtype == torch.float8_e4m3fn or weight.dtype == torch.bfloat16
|
||||
out_file.write(weight.view(torch.uint8).numpy().tobytes())
|
||||
|
||||
|
||||
def write_rmsnorm_gain(out_file, gain):
|
||||
write_padding(out_file, alignment_multiple=16)
|
||||
|
||||
assert gain.dtype == torch.bfloat16
|
||||
out_file.write(gain.view(torch.uint8).numpy().tobytes())
|
||||
|
||||
|
||||
def write_attn_sink(out_file, sink):
|
||||
write_padding(out_file, alignment_multiple=16)
|
||||
|
||||
assert sink.dtype == torch.bfloat16
|
||||
out_file.write(sink.view(torch.uint8).numpy().tobytes())
|
||||
|
||||
|
||||
def write_linear_weight(out_file, *args):
|
||||
write_padding(out_file, alignment_multiple=16)
|
||||
|
||||
for t in args:
|
||||
out_file.write(t.view(torch.uint8).numpy().tobytes())
|
||||
|
||||
|
||||
def main(args):
|
||||
options = parser.parse_args(args)
|
||||
|
||||
with open(os.path.join(options.src, "config.json"), "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
num_blocks = config["num_hidden_layers"]
|
||||
num_experts = config["num_experts"]
|
||||
num_active_experts = 4
|
||||
num_q_heads = config["num_attention_heads"]
|
||||
num_kv_heads = config["num_key_value_heads"]
|
||||
head_dim = config["head_dim"]
|
||||
embedding_dim = config["hidden_size"]
|
||||
mlp_dim = config["intermediate_size"]
|
||||
swiglu_limit = config.get("swiglu_limit", 7.0)
|
||||
rope_theta = config["rope_theta"]
|
||||
attention_window = config["sliding_window"]
|
||||
initial_context_length = config["initial_context_length"]
|
||||
rope_scaling_factor = config["rope_scaling_factor"]
|
||||
rope_ntk_alpha = config["rope_ntk_alpha"]
|
||||
rope_ntk_beta = config["rope_ntk_beta"]
|
||||
|
||||
tokens_size = 0
|
||||
num_text_tokens = 0
|
||||
# First add all text tokens
|
||||
for t in range(o200k_gptoss.n_vocab):
|
||||
if not harmony_encoding.is_special_token(t):
|
||||
token_bytes = o200k_gptoss.decode_single_token_bytes(t)
|
||||
assert len(token_bytes) > 0
|
||||
tokens_size += len(token_bytes) + 2 # uint16_t string length + string data
|
||||
num_text_tokens += 1
|
||||
# Then add all special tokens
|
||||
num_included_tokens = 200013 + 1
|
||||
print(f"Tokenizer: {num_included_tokens} tokens")
|
||||
|
||||
tensors = {}
|
||||
with open(options.dst, "wb") as dst:
|
||||
with safe_open(os.path.join(options.src, "model.safetensors"), framework="pt", device="cpu") as src:
|
||||
write_file_header(dst)
|
||||
|
||||
yarn_low = (
|
||||
head_dim / 2
|
||||
* math.log(initial_context_length / (rope_ntk_beta * 2 * math.pi))
|
||||
/ math.log(rope_theta)
|
||||
)
|
||||
yarn_high = (
|
||||
head_dim / 2
|
||||
* math.log(initial_context_length / (rope_ntk_alpha * 2 * math.pi))
|
||||
/ math.log(rope_theta)
|
||||
)
|
||||
|
||||
write_model_header(dst,
|
||||
context_length=int(initial_context_length * rope_scaling_factor),
|
||||
num_blocks=num_blocks,
|
||||
num_experts=num_experts,
|
||||
num_active_experts=num_active_experts,
|
||||
embedding_dim=embedding_dim,
|
||||
mlp_dim=mlp_dim,
|
||||
swiglu_limit=swiglu_limit,
|
||||
head_dim=head_dim,
|
||||
num_heads=num_q_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
attention_window=attention_window,
|
||||
rope_theta=rope_theta,
|
||||
interpolation_scale=1.0 / rope_scaling_factor,
|
||||
yarn_offset=-yarn_low / (yarn_high - yarn_low),
|
||||
yarn_scale=1.0 / (yarn_high - yarn_low),
|
||||
yarn_multiplier=0.1 * math.log(rope_scaling_factor) + 1.0,
|
||||
rmsnorm_epsilon=1.0e-5)
|
||||
|
||||
write_tokenizer_header(dst,
|
||||
num_special_tokens=num_included_tokens - num_text_tokens,
|
||||
num_text_tokens=num_text_tokens,
|
||||
regex_size=len(o200k_gptoss._pat_str.encode("ascii")) + 1,
|
||||
tokens_size=tokens_size)
|
||||
|
||||
### Tokenizer
|
||||
# Special tokens
|
||||
for token_idx in range(num_text_tokens, num_included_tokens):
|
||||
token = o200k_gptoss.decode_single_token_bytes(token_idx).decode('ascii')
|
||||
if token in INCLUDE_SPECIAL_TOKENS:
|
||||
dst.write(SPECIAL_TOKEN_UUID[token])
|
||||
else:
|
||||
dst.write(bytes(16))
|
||||
# Regex
|
||||
dst.write(o200k_gptoss._pat_str.encode("ascii"))
|
||||
dst.write(struct.pack('B', 0))
|
||||
# Text tokens
|
||||
tokenizer_bytes_written = 0
|
||||
for t in range(num_text_tokens):
|
||||
token_bytes = o200k_gptoss.decode_single_token_bytes(t)
|
||||
assert len(token_bytes) > 0
|
||||
dst.write(struct.pack('<H', len(token_bytes)))
|
||||
dst.write(token_bytes)
|
||||
tokenizer_bytes_written += len(token_bytes) + 2
|
||||
assert(tokenizer_bytes_written == tokens_size), (tokenizer_bytes_written, tokens_size)
|
||||
write_padding(dst)
|
||||
|
||||
embedding_weight = src.get_tensor("embedding.weight")
|
||||
# Filter out unused tokens
|
||||
embedding_weight = embedding_weight[:num_included_tokens, :]
|
||||
write_embedding_weight(dst, embedding_weight)
|
||||
|
||||
for n in tqdm(range(num_blocks)):
|
||||
write_rmsnorm_gain(dst, src.get_tensor(f"block.{n}.attn.norm.scale"))
|
||||
|
||||
attn_qkv_weight = src.get_tensor(f"block.{n}.attn.qkv.weight")
|
||||
attn_qkv_bias = src.get_tensor(f"block.{n}.attn.qkv.bias")
|
||||
for qkv in (attn_qkv_weight, attn_qkv_bias):
|
||||
qk = qkv[:head_dim * (num_q_heads + num_kv_heads), ...].contiguous()
|
||||
v = qkv[head_dim * (num_q_heads + num_kv_heads):, ...].contiguous()
|
||||
qk = qk.view(num_q_heads + num_kv_heads, 2, head_dim // 2, -1).transpose(1, 2).reshape(num_q_heads + num_kv_heads, head_dim, -1)
|
||||
q = qk[:num_q_heads, ...]
|
||||
k = qk[num_q_heads:, ...]
|
||||
# Factor multiplication by 1/sqrt(64) = 0.125 = 0.5 * 0.25 in SDPA into Q and K projections
|
||||
assert head_dim == 64
|
||||
q *= 0.5
|
||||
k *= 0.25
|
||||
v = v.view(num_kv_heads, head_dim, -1)
|
||||
qkv.copy_(torch.cat((q, k, v), dim=0).reshape(*qkv.shape))
|
||||
|
||||
write_linear_weight(dst, attn_qkv_weight, attn_qkv_bias)
|
||||
|
||||
write_attn_sink(dst, src.get_tensor(f"block.{n}.attn.sinks"))
|
||||
|
||||
write_linear_weight(dst, src.get_tensor(f"block.{n}.attn.out.weight"), src.get_tensor(f"block.{n}.attn.out.bias"))
|
||||
|
||||
write_rmsnorm_gain(dst, src.get_tensor(f"block.{n}.mlp.norm.scale"))
|
||||
|
||||
write_linear_weight(dst, src.get_tensor(f"block.{n}.mlp.gate.weight"), src.get_tensor(f"block.{n}.mlp.gate.bias"))
|
||||
|
||||
write_rmsnorm_gain(dst, src.get_tensor("norm.scale"))
|
||||
|
||||
unembedding_weight = src.get_tensor("unembedding.weight")
|
||||
unembedding_weight = unembedding_weight[:num_included_tokens, :]
|
||||
write_linear_weight(dst, unembedding_weight)
|
||||
|
||||
for n in tqdm(range(num_blocks)):
|
||||
mlp1_blocks = src.get_tensor(f"block.{n}.mlp.mlp1_weight.blocks")
|
||||
mlp1_scales = src.get_tensor(f"block.{n}.mlp.mlp1_weight.scales")
|
||||
assert mlp1_scales.min().item() < 254 - UE8_OFFSET
|
||||
mlp1_bias = src.get_tensor(f"block.{n}.mlp.mlp1_bias")
|
||||
|
||||
mlp2_blocks = src.get_tensor(f"block.{n}.mlp.mlp2_weight.blocks")
|
||||
mlp2_scales = src.get_tensor(f"block.{n}.mlp.mlp2_weight.scales")
|
||||
assert mlp2_scales.min().item() < 254 - UE8_OFFSET
|
||||
mlp2_bias = src.get_tensor(f"block.{n}.mlp.mlp2_bias")
|
||||
|
||||
# Write MoE weights grouped by expert
|
||||
write_padding(dst)
|
||||
|
||||
for e in range(num_experts):
|
||||
write_padding(dst, alignment_multiple=16)
|
||||
dst.write(mlp1_blocks[e, ...].view(torch.uint8).numpy().tobytes())
|
||||
|
||||
write_padding(dst, alignment_multiple=16)
|
||||
dst.write((mlp1_scales + UE8_OFFSET)[e, ...].view(torch.uint8).numpy().tobytes())
|
||||
|
||||
write_padding(dst, alignment_multiple=16)
|
||||
dst.write(mlp1_bias[e, ...].view(torch.uint8).numpy().tobytes())
|
||||
|
||||
write_padding(dst, alignment_multiple=16)
|
||||
dst.write(mlp2_blocks[e, ...].view(torch.uint8).numpy().tobytes())
|
||||
|
||||
write_padding(dst, alignment_multiple=16)
|
||||
dst.write((mlp2_scales + UE8_OFFSET)[e, ...].view(torch.uint8).numpy().tobytes())
|
||||
|
||||
write_padding(dst, alignment_multiple=16)
|
||||
dst.write(mlp2_bias[e, ...].view(torch.uint8).numpy().tobytes())
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(sys.argv[1:])
|
||||
55
gpt_oss/metal/source/accumulate.metal
Normal file
55
gpt_oss/metal/source/accumulate.metal
Normal file
@@ -0,0 +1,55 @@
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
#include <internal/kernel-args.h>
|
||||
|
||||
#pragma METAL fp math_mode(safe)
|
||||
#pragma METAL fp contract(off)
|
||||
|
||||
|
||||
kernel void gptoss_f32_accumulate_e4(
|
||||
constant gptoss_accumulate_args& args [[ buffer(0) ]],
|
||||
const device float4* input [[ buffer(1) ]],
|
||||
const device gptoss_expert_prediction* expert [[ buffer(2) ]],
|
||||
device float4* output [[ buffer(3) ]],
|
||||
uint2 gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_index_in_threadgroup]],
|
||||
uint2 threadgroup_size [[ threads_per_threadgroup ]])
|
||||
{
|
||||
const uint num_active_experts = 4;
|
||||
|
||||
const uint num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;
|
||||
const uint threadgroup_start = gid.x * num_vecs_per_threadgroup;
|
||||
const uint num_vecs = args.num_vecs;
|
||||
const uint threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, num_vecs);
|
||||
const uint thread_start = threadgroup_start + tid;
|
||||
uint num_iter = static_cast<uint>((threadgroup_end - thread_start + (threadgroup_size.x - 1)) / threadgroup_size.x);
|
||||
|
||||
const uint num_vecs_per_expert = args.num_vecs_per_expert;
|
||||
const float scale0 = expert[gid.y * num_active_experts + 0].score;
|
||||
const device float4* input0 = input + gid.y * num_vecs + thread_start;
|
||||
const float scale1 = expert[gid.y * num_active_experts + 1].score;
|
||||
const device float4* input1 = input0 + num_vecs_per_expert;
|
||||
const float scale2 = expert[gid.y * num_active_experts + 2].score;
|
||||
const device float4* input2 = input1 + num_vecs_per_expert;
|
||||
const float scale3 = expert[gid.y * num_active_experts + 3].score;
|
||||
const device float4* input3 = input2 + num_vecs_per_expert;
|
||||
output += gid.y * num_vecs + thread_start;
|
||||
for (; num_iter != 0; num_iter--) {
|
||||
float4 acc = *output;
|
||||
const float4 val0 = *input0;
|
||||
const float4 val1 = *input1;
|
||||
const float4 val2 = *input2;
|
||||
const float4 val3 = *input3;
|
||||
input0 += threadgroup_size.x;
|
||||
acc = metal::fma(val0, scale0, acc);
|
||||
input1 += threadgroup_size.x;
|
||||
acc = metal::fma(val1, scale1, acc);
|
||||
input2 += threadgroup_size.x;
|
||||
acc = metal::fma(val2, scale2, acc);
|
||||
input3 += threadgroup_size.x;
|
||||
acc = metal::fma(val3, scale3, acc);
|
||||
*output = acc;
|
||||
output += threadgroup_size.x;
|
||||
}
|
||||
}
|
||||
717
gpt_oss/metal/source/context.c
Normal file
717
gpt_oss/metal/source/context.c
Normal file
@@ -0,0 +1,717 @@
|
||||
#include <assert.h>
|
||||
#include <float.h>
|
||||
#include <inttypes.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <gpt-oss.h>
|
||||
|
||||
#include "internal/datatype.h"
|
||||
#include "internal/model.h"
|
||||
#include "internal/metal.h"
|
||||
#include "internal/metal-kernels.h"
|
||||
#include "internal/log.h"
|
||||
#include "internal/rng.h"
|
||||
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_context_create(
|
||||
gptoss_model_t model,
|
||||
size_t context_length,
|
||||
gptoss_context_t* context_out)
|
||||
{
|
||||
*context_out = NULL;
|
||||
|
||||
enum gptoss_status status = gptoss_status_success;
|
||||
struct gptoss_context* context = NULL;
|
||||
|
||||
if (context_length == 0) {
|
||||
context_length = model->context_length;
|
||||
} else if (context_length > model->context_length) {
|
||||
GPTOSS_LOG_ERROR("requested context length %zu exceeds model context length %" PRIu32,
|
||||
context_length, model->context_length);
|
||||
status = gptoss_status_invalid_argument;
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
context = malloc(sizeof(struct gptoss_context));
|
||||
if (context == NULL) {
|
||||
GPTOSS_LOG_ERROR("failed to allocate %zu bytes for Context object",
|
||||
sizeof(struct gptoss_context));
|
||||
status = gptoss_status_insufficient_memory;
|
||||
goto cleanup;
|
||||
}
|
||||
memset(context, 0, sizeof(struct gptoss_context));
|
||||
|
||||
atomic_store_explicit(&context->ref_count, 1, memory_order_relaxed);
|
||||
context->max_tokens = context_length;
|
||||
|
||||
status = gptoss_metal_buffer_create(&model->device, context_length * sizeof(uint32_t), NULL, &context->token_buffer);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->vocabulary_size * sizeof(float), NULL, &context->score_buffer);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->vocabulary_size * sizeof(float), NULL, &context->prob_buffer);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->max_threadgroups * sizeof(float), NULL, &context->sum_buffer);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * sizeof(uint64_t), NULL, &context->argmax_buffer);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_buffer_create(&model->device, model->num_blocks * context_length * 2 * model->num_kv_heads * model->head_dim * sizeof(float), NULL, &context->kvcache_buffer);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
context->kvcache_size = context->kvcache_buffer.size;
|
||||
context->allocation_size = context->token_buffer.size + context->kvcache_buffer.size + context->score_buffer.size + context->argmax_buffer.size;
|
||||
|
||||
context->model = model;
|
||||
gptoss_model_retain(model);
|
||||
*context_out = context;
|
||||
context = NULL;
|
||||
|
||||
cleanup:
|
||||
gptoss_context_release(context);
|
||||
return status;
|
||||
}
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_context_get_num_tokens(
|
||||
gptoss_context_t context,
|
||||
size_t* num_tokens_out)
|
||||
{
|
||||
*num_tokens_out = context->num_tokens;
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_context_get_max_tokens(
|
||||
gptoss_context_t context,
|
||||
size_t* max_tokens_out)
|
||||
{
|
||||
*max_tokens_out = context->max_tokens;
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_context_get_tokens(
|
||||
gptoss_context_t context,
|
||||
uint32_t* tokens_out,
|
||||
size_t max_tokens,
|
||||
size_t* num_tokens_out)
|
||||
{
|
||||
*num_tokens_out = context->num_tokens;
|
||||
if (max_tokens < context->num_tokens) {
|
||||
return gptoss_status_insufficient_memory;
|
||||
}
|
||||
|
||||
if (context->num_tokens != 0) {
|
||||
memcpy(tokens_out, context->token_buffer.ptr, context->num_tokens * sizeof(uint32_t));
|
||||
}
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
static enum gptoss_status process_batch(
|
||||
gptoss_context_t context)
|
||||
{
|
||||
enum gptoss_status status = gptoss_status_success;
|
||||
const struct gptoss_model* model = context->model;
|
||||
struct gptoss_metal_command_buffer command_buffer = {0};
|
||||
|
||||
const size_t attn_qkv_dim = model->head_dim * (model->num_heads + 2 * model->num_kv_heads);
|
||||
|
||||
status = gptoss_metal_command_buffer_create(&model->command_queue, &command_buffer);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(
|
||||
&command_buffer,
|
||||
&model->bf16_f32_embeddings_fn,
|
||||
/*threadgroup_size=*/512,
|
||||
&context->token_buffer,
|
||||
(context->num_tokens - context->num_batch_tokens) * sizeof(uint32_t),
|
||||
&model->shared_weight_buffer,
|
||||
/*weight_offset=*/0,
|
||||
&model->residual_activation_buffer,
|
||||
/*output_offset=*/0,
|
||||
/*num_tokens=*/context->num_batch_tokens,
|
||||
/*num_channels=*/model->embedding_dim);
|
||||
if (status != gptoss_status_success) {
|
||||
GPTOSS_LOG_ERROR("failed to encode bf16_f32_embeddings kernel launch");
|
||||
goto cleanup;
|
||||
}
|
||||
for (uint32_t n = 0; n < model->num_blocks; n++) {
|
||||
const bool last_block = n + 1 == model->num_blocks;
|
||||
const size_t num_output_tokens = last_block ? 1 : context->num_batch_tokens;
|
||||
|
||||
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
|
||||
&command_buffer,
|
||||
&model->f32_bf16w_rmsnorm_fn,
|
||||
&model->residual_activation_buffer,
|
||||
/*input_offset=*/0,
|
||||
&model->shared_weight_buffer,
|
||||
/*weight_offset=*/model->attn_rmsnorm_gain_offset + model->per_block_shared_weights_size * n,
|
||||
&model->rmsnorm_activation_buffer,
|
||||
/*output_offset=*/0,
|
||||
/*num_tokens=*/context->num_batch_tokens,
|
||||
/*num_channels=*/model->embedding_dim,
|
||||
model->rmsnorm_epsilon);
|
||||
if (status != gptoss_status_success) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
|
||||
&command_buffer,
|
||||
&model->f32_bf16w_matmul_fn,
|
||||
/*threadgroup_size=*/256,
|
||||
&model->rmsnorm_activation_buffer,
|
||||
/*input_offset=*/0,
|
||||
&model->shared_weight_buffer,
|
||||
/*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n,
|
||||
&model->shared_weight_buffer,
|
||||
/*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n,
|
||||
&model->qkv_activation_buffer,
|
||||
/*output_offset=*/0,
|
||||
/*num_tokens=*/context->num_batch_tokens,
|
||||
/*num_cols=*/model->embedding_dim,
|
||||
/*num_rows=*/attn_qkv_dim);
|
||||
if (status != gptoss_status_success) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch");
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
status = gptoss_metal_command_buffer_encode_launch_f32_rope(
|
||||
&command_buffer,
|
||||
&model->f32_rope_fn,
|
||||
/*threadgroup_size=*/32,
|
||||
&model->qkv_activation_buffer,
|
||||
model->rope_theta,
|
||||
model->interpolation_scale,
|
||||
model->yarn_offset,
|
||||
model->yarn_scale,
|
||||
model->yarn_multiplier,
|
||||
context->num_batch_tokens,
|
||||
model->num_heads,
|
||||
model->num_kv_heads,
|
||||
model->head_dim,
|
||||
/*token_offset=*/context->num_kv_tokens);
|
||||
if (status != gptoss_status_success) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_rope kernel launch");
|
||||
goto cleanup;
|
||||
}
|
||||
for (uint32_t t = 0; t < context->num_batch_tokens; t++) {
|
||||
status = gptoss_metal_command_buffer_encode_copy_buffer(
|
||||
&command_buffer,
|
||||
&model->qkv_activation_buffer,
|
||||
/*input_offset=*/(t * attn_qkv_dim + model->num_heads * model->head_dim) * sizeof(float),
|
||||
&context->kvcache_buffer,
|
||||
/*output_offset=*/(n * context->max_tokens + context->num_kv_tokens + t) * 2 * model->num_kv_heads * model->head_dim * sizeof(float),
|
||||
/*size=*/2 * model->num_kv_heads * model->head_dim * sizeof(float));
|
||||
if (status != gptoss_status_success) {
|
||||
GPTOSS_LOG_ERROR("failed to encode copy of token %" PRIu32 " to KV cache", t);
|
||||
goto cleanup;
|
||||
}
|
||||
}
|
||||
|
||||
status = gptoss_metal_command_buffer_encode_launch_f32_sdpa(
|
||||
&command_buffer,
|
||||
&model->f32_sdpa_q8_d64_fn,
|
||||
&model->qkv_activation_buffer,
|
||||
/*q_offset=*/attn_qkv_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float),
|
||||
&context->kvcache_buffer,
|
||||
/*k_offset=*/n * context->max_tokens * 2 * model->num_kv_heads * model->head_dim * sizeof(float),
|
||||
&context->kvcache_buffer,
|
||||
/*v_offset=*/(n * context->max_tokens * 2 + 1) * model->num_kv_heads * model->head_dim * sizeof(float),
|
||||
&model->shared_weight_buffer,
|
||||
/*s_offset=*/model->attn_sdpa_sink_offset + model->per_block_shared_weights_size * n,
|
||||
&model->sdpa_activation_buffer, /*output_offset=*/0,
|
||||
/*window=*/n % 2 == 0 ? model->attention_window : UINT32_MAX,
|
||||
num_output_tokens, context->num_kv_tokens + (context->num_batch_tokens - num_output_tokens),
|
||||
model->num_heads, model->num_kv_heads, model->head_dim);
|
||||
if (status != gptoss_status_success) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_sdpa kernel launch");
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(
|
||||
&command_buffer,
|
||||
&model->f32_bf16w_matmul_fn,
|
||||
/*threadgroup_size=*/256,
|
||||
&model->sdpa_activation_buffer,
|
||||
/*input_offset=*/0,
|
||||
&model->shared_weight_buffer,
|
||||
/*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n,
|
||||
&model->shared_weight_buffer,
|
||||
/*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n,
|
||||
&model->residual_activation_buffer,
|
||||
/*output_offset=*/model->embedding_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float),
|
||||
/*num_tokens=*/num_output_tokens,
|
||||
/*num_cols=*/model->num_heads * model->head_dim,
|
||||
/*num_rows=*/model->embedding_dim);
|
||||
if (status != gptoss_status_success) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch");
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
|
||||
&command_buffer,
|
||||
&model->f32_bf16w_rmsnorm_fn,
|
||||
&model->residual_activation_buffer,
|
||||
/*input_offset=*/model->embedding_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float),
|
||||
&model->shared_weight_buffer,
|
||||
/*weight_offset=*/model->mlp_rmsnorm_gain_offset + model->per_block_shared_weights_size * n,
|
||||
&model->rmsnorm_activation_buffer,
|
||||
/*output_offset=*/0,
|
||||
num_output_tokens,
|
||||
model->embedding_dim,
|
||||
model->rmsnorm_epsilon);
|
||||
if (status != gptoss_status_success) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
|
||||
&command_buffer,
|
||||
&model->f32_bf16w_matmul_fn,
|
||||
/*threadgroup_size=*/256,
|
||||
&model->rmsnorm_activation_buffer,
|
||||
/*input_offset=*/0,
|
||||
&model->shared_weight_buffer,
|
||||
/*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n,
|
||||
&model->shared_weight_buffer,
|
||||
/*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n,
|
||||
&model->gate_activation_buffer,
|
||||
/*output_offset=*/0,
|
||||
/*num_tokens=*/num_output_tokens,
|
||||
/*num_cols=*/model->embedding_dim,
|
||||
/*num_rows=*/model->num_experts);
|
||||
if (status != gptoss_status_success) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch");
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
const char* kernel_name = NULL;
|
||||
switch (model->num_experts) {
|
||||
case 32:
|
||||
kernel_name = "f32_topk_softmax_e32_k4_fn";
|
||||
status = gptoss_metal_command_buffer_encode_launch_f32_topk(
|
||||
&command_buffer,
|
||||
&model->f32_topk_softmax_e32_k4_fn,
|
||||
&model->gate_activation_buffer, /*input_offset=*/0,
|
||||
&model->expert_activation_buffer, /*output_offset=*/0,
|
||||
num_output_tokens,
|
||||
model->num_experts,
|
||||
model->num_active_experts);
|
||||
break;
|
||||
case 128:
|
||||
kernel_name = "f32_topk_softmax_e128_k4_fn";
|
||||
status = gptoss_metal_command_buffer_encode_launch_f32_topk(
|
||||
&command_buffer,
|
||||
&model->f32_topk_softmax_e128_k4_fn,
|
||||
&model->gate_activation_buffer, /*input_offset=*/0,
|
||||
&model->expert_activation_buffer, /*output_offset=*/0,
|
||||
num_output_tokens,
|
||||
model->num_experts,
|
||||
model->num_active_experts);
|
||||
break;
|
||||
default:
|
||||
status = gptoss_status_unsupported_argument;
|
||||
GPTOSS_LOG_ERROR("missing Top-K kernel for %" PRIu32 " experts", model->num_experts);
|
||||
goto cleanup;
|
||||
}
|
||||
if (status != gptoss_status_success) {
|
||||
GPTOSS_LOG_ERROR("failed to encode %s kernel launch", kernel_name);
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(
|
||||
&command_buffer,
|
||||
&model->f32_mf4w_moe_matmul_swiglu_fn,
|
||||
/*threadgroup_size=*/512,
|
||||
&model->rmsnorm_activation_buffer, /*input_offset=*/0,
|
||||
&model->expert_activation_buffer, /*expert_offset=*/0,
|
||||
&model->block_weight_buffers[n], /*weight_block_offset=*/0,
|
||||
&model->block_weight_buffers[n], /*weight_scale_offset=*/model->mlp_swiglu_scale_offset,
|
||||
&model->block_weight_buffers[n], /*bias_offset=*/model->mlp_swiglu_bias_offset,
|
||||
&model->swiglu_activation_buffer, /*output_offset=*/0,
|
||||
model->swiglu_limit,
|
||||
model->per_expert_block_weight_size,
|
||||
num_output_tokens,
|
||||
model->num_active_experts,
|
||||
model->embedding_dim,
|
||||
model->mlp_dim);
|
||||
if (status != gptoss_status_success) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch");
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(
|
||||
&command_buffer,
|
||||
&model->f32_mf4w_moe_matmul_fn,
|
||||
/*threadgroup_size=*/512,
|
||||
&model->swiglu_activation_buffer, /*input_offset=*/0,
|
||||
&model->expert_activation_buffer, /*expert_offset=*/0,
|
||||
&model->block_weight_buffers[n], /*weight_block_offset=*/model->mlp_out_block_offset,
|
||||
&model->block_weight_buffers[n], /*weight_scale_offset=*/model->mlp_out_scale_offset,
|
||||
&model->block_weight_buffers[n], /*bias_offset=*/model->mlp_out_bias_offset,
|
||||
&model->moe_activation_buffer, /*output_offset=*/0,
|
||||
model->per_expert_block_weight_size,
|
||||
num_output_tokens,
|
||||
model->num_active_experts,
|
||||
model->mlp_dim,
|
||||
model->embedding_dim);
|
||||
if (status != gptoss_status_success) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch");
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
status = gptoss_metal_command_buffer_encode_launch_f32_accumulate(
|
||||
&command_buffer,
|
||||
&model->f32_accumulate_e4_fn,
|
||||
/*threadgroup_size=*/256,
|
||||
model->max_threadgroups,
|
||||
&model->moe_activation_buffer,
|
||||
/*input_offset=*/0,
|
||||
&model->expert_activation_buffer,
|
||||
/*expert_offset=*/0,
|
||||
&model->residual_activation_buffer,
|
||||
/*output_offset=*/model->embedding_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float),
|
||||
model->embedding_dim,
|
||||
num_output_tokens,
|
||||
model->num_active_experts);
|
||||
if (status != gptoss_status_success) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_accumulate kernel launch");
|
||||
goto cleanup;
|
||||
}
|
||||
}
|
||||
|
||||
const size_t num_output_tokens = 1;
|
||||
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
|
||||
&command_buffer,
|
||||
&model->f32_bf16w_rmsnorm_fn,
|
||||
&model->residual_activation_buffer,
|
||||
/*input_offset=*/model->embedding_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float),
|
||||
&model->shared_weight_buffer,
|
||||
/*weight_offset=*/model->rmsnorm_weight_offset,
|
||||
&model->rmsnorm_activation_buffer,
|
||||
/*output_offset=*/0,
|
||||
/*num_tokens=*/num_output_tokens,
|
||||
/*num_channels=*/model->embedding_dim,
|
||||
model->rmsnorm_epsilon);
|
||||
if (status != gptoss_status_success) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
status = gptoss_metal_command_buffer_encode_fill_buffer(
|
||||
&command_buffer,
|
||||
&context->argmax_buffer,
|
||||
/*offset=*/0,
|
||||
/*size=*/sizeof(uint64_t) * num_output_tokens,
|
||||
/*fill_value=*/0xFF);
|
||||
if (status != gptoss_status_success) {
|
||||
GPTOSS_LOG_ERROR("failed to encode fill buffer command");
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding(
|
||||
&command_buffer,
|
||||
&model->f32_bf16w_unembedding_fn,
|
||||
/*threadgroup_size=*/256,
|
||||
model->max_threadgroups,
|
||||
&model->rmsnorm_activation_buffer,
|
||||
/*input_offset=*/0,
|
||||
&model->shared_weight_buffer,
|
||||
/*weight_offset=*/model->unembedding_weight_offset,
|
||||
&context->score_buffer,
|
||||
/*output_offset=*/0,
|
||||
&context->argmax_buffer,
|
||||
/*argmax_offset=*/0,
|
||||
/*num_tokens=*/num_output_tokens,
|
||||
/*num_cols=*/model->embedding_dim,
|
||||
/*num_rows=*/model->vocabulary_size);
|
||||
if (status != gptoss_status_success) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_unembedding kernel launch");
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
gptoss_metal_command_buffer_commit(&command_buffer);
|
||||
gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);
|
||||
|
||||
context->num_kv_tokens = context->num_tokens;
|
||||
context->num_processed_tokens = num_output_tokens;
|
||||
context->num_batch_tokens = 0;
|
||||
|
||||
cleanup:
|
||||
gptoss_metal_command_buffer_release(&command_buffer);
|
||||
return status;
|
||||
}
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_context_append_chars(
|
||||
gptoss_context_t context,
|
||||
const char* text,
|
||||
size_t text_length,
|
||||
size_t* num_tokens_out)
|
||||
{
|
||||
enum gptoss_status status = gptoss_status_success;
|
||||
const struct gptoss_model* model = context->model;
|
||||
const struct gptoss_tokenizer* tokenizer = model->tokenizer;
|
||||
size_t num_appended_tokens = 0;
|
||||
while (text_length != 0) {
|
||||
if (context->num_tokens == context->max_tokens) {
|
||||
status = gptoss_status_context_overflow;
|
||||
break;
|
||||
}
|
||||
const char* tokens = tokenizer->tokens_ptr;
|
||||
uint32_t best_token = UINT32_MAX;
|
||||
uint32_t best_token_length = 0;
|
||||
for (size_t t = 0; t < tokenizer->num_text_tokens; t++) {
|
||||
uint16_t token_length;
|
||||
memcpy(&token_length, tokens, sizeof(uint16_t));
|
||||
tokens += sizeof(uint16_t);
|
||||
if (token_length <= text_length && token_length > best_token_length) {
|
||||
if (memcmp(text, tokens, token_length) == 0) {
|
||||
if (token_length > best_token_length) {
|
||||
best_token = (uint32_t) t;
|
||||
best_token_length = token_length;
|
||||
}
|
||||
}
|
||||
}
|
||||
tokens += token_length;
|
||||
}
|
||||
|
||||
if (best_token == UINT32_MAX) {
|
||||
GPTOSS_LOG_ERROR("failed to tokenize text \"%.*s\"", (int) text_length, text);
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
uint32_t* input_tokens = (uint32_t*) context->token_buffer.ptr;
|
||||
input_tokens[context->num_tokens] = best_token;
|
||||
context->num_tokens++;
|
||||
num_appended_tokens++;
|
||||
if (++context->num_batch_tokens == model->max_batch_tokens) {
|
||||
status = process_batch(context);
|
||||
if (status != gptoss_status_success) {
|
||||
break;
|
||||
}
|
||||
assert(context->num_batch_tokens == 0);
|
||||
}
|
||||
assert(context->num_batch_tokens < model->max_batch_tokens);
|
||||
text += best_token_length;
|
||||
text_length -= best_token_length;
|
||||
}
|
||||
if (num_tokens_out != NULL) {
|
||||
*num_tokens_out = num_appended_tokens;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_context_append_tokens(
|
||||
gptoss_context_t context,
|
||||
size_t num_tokens,
|
||||
const uint32_t* tokens)
|
||||
{
|
||||
const struct gptoss_model* model = context->model;
|
||||
|
||||
// Validate all tokens
|
||||
for (size_t t = 0; t < num_tokens; t++) {
|
||||
const uint32_t token = tokens[t];
|
||||
if (token >= model->vocabulary_size) {
|
||||
GPTOSS_LOG_ERROR("token %" PRIu32 " at index %zu is out of bounds for vocabulary size %" PRIu32,
|
||||
token, t, context->model->vocabulary_size);
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
}
|
||||
|
||||
enum gptoss_status status = gptoss_status_success;
|
||||
uint32_t* input_tokens = (uint32_t*) context->token_buffer.ptr;
|
||||
while (num_tokens != 0) {
|
||||
assert(context->num_batch_tokens < model->max_batch_tokens);
|
||||
if (context->num_tokens == context->max_tokens) {
|
||||
status = gptoss_status_context_overflow;
|
||||
break;
|
||||
}
|
||||
|
||||
const size_t num_tokens_to_copy =
|
||||
math_min(context->max_tokens - context->num_tokens,
|
||||
math_min(num_tokens, model->max_batch_tokens - context->num_batch_tokens));
|
||||
memcpy(input_tokens + context->num_tokens, tokens, num_tokens_to_copy * sizeof(uint32_t));
|
||||
context->num_tokens += num_tokens_to_copy;
|
||||
context->num_batch_tokens += num_tokens_to_copy;
|
||||
if (context->num_batch_tokens == model->max_batch_tokens) {
|
||||
status = process_batch(context);
|
||||
if (status != gptoss_status_success) {
|
||||
break;
|
||||
}
|
||||
assert(context->num_batch_tokens == 0);
|
||||
}
|
||||
tokens += num_tokens_to_copy;
|
||||
num_tokens -= num_tokens_to_copy;
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_context_process(
|
||||
gptoss_context_t context)
|
||||
{
|
||||
if (context->num_batch_tokens != 0) {
|
||||
process_batch(context);
|
||||
}
|
||||
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_context_sample(
|
||||
gptoss_context_t context,
|
||||
float temperature,
|
||||
uint64_t seed,
|
||||
uint32_t* token_out)
|
||||
{
|
||||
enum gptoss_status status = gptoss_status_success;
|
||||
const struct gptoss_model* model = context->model;
|
||||
struct gptoss_metal_command_buffer command_buffer = {0};
|
||||
|
||||
*token_out = UINT32_MAX;
|
||||
if (context->num_batch_tokens != 0) {
|
||||
status = process_batch(context);
|
||||
if (status != gptoss_status_success) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
||||
if (temperature == 0.0f) {
|
||||
const uint64_t argmax_bits = ((const uint64_t*) context->argmax_buffer.ptr)[0];
|
||||
*token_out = (uint32_t) argmax_bits;
|
||||
} else {
|
||||
assert(context->num_processed_tokens != 0);
|
||||
status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
uint32_t num_threadgroups = 0;
|
||||
uint32_t num_dims_per_threadgroup = 0;
|
||||
status = gptoss_metal_command_buffer_encode_launch_f32_softmax(
|
||||
&command_buffer,
|
||||
&model->f32_softmax_fn,
|
||||
/*threadgroup_size=*/256,
|
||||
model->max_threadgroups,
|
||||
&context->score_buffer,
|
||||
/*score_offset=*/0,
|
||||
&context->argmax_buffer,
|
||||
/*argmax_offset=*/0,
|
||||
&context->prob_buffer,
|
||||
/*prob_offset=*/0,
|
||||
&context->sum_buffer,
|
||||
/*sum_offset=*/0,
|
||||
model->vocabulary_size,
|
||||
/*num_tokens=*/1,
|
||||
temperature,
|
||||
&num_threadgroups,
|
||||
&num_dims_per_threadgroup);
|
||||
if (status != gptoss_status_success) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_softmax kernel launch");
|
||||
}
|
||||
|
||||
gptoss_metal_command_buffer_commit(&command_buffer);
|
||||
gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);
|
||||
|
||||
const uint32_t sample_word = rng_squares32(context->num_tokens, seed + UINT64_C(0x123456789ABCDEF));
|
||||
float sample_cdf = (float) ((int32_t) sample_word & INT32_C(0x00FFFFFF)) * 0x1.0p-24f;
|
||||
|
||||
const float* sum_ptr = (const float*) context->sum_buffer.ptr;
|
||||
float sum = 0.0f;
|
||||
for (uint32_t i = 0; i < num_threadgroups; i++) {
|
||||
sum += sum_ptr[i];
|
||||
}
|
||||
sample_cdf *= sum;
|
||||
|
||||
uint32_t block_idx = 0, token_idx = 0;
|
||||
if (sample_cdf == 0.0f) {
|
||||
// Make sure we choose the first token with non-zero probability rather than just the first token
|
||||
sample_cdf = FLT_TRUE_MIN;
|
||||
}
|
||||
|
||||
// Step 1: find block
|
||||
float cumsum = 0.0f;
|
||||
for (; block_idx < num_threadgroups; block_idx++) {
|
||||
const float new_cumsum = cumsum + sum_ptr[block_idx];
|
||||
if (new_cumsum >= sample_cdf) {
|
||||
break;
|
||||
}
|
||||
cumsum = new_cumsum;
|
||||
}
|
||||
if (block_idx == num_threadgroups) {
|
||||
block_idx -= 1;
|
||||
}
|
||||
|
||||
// Step 2: find token
|
||||
const float* prob_ptr = (const float*) context->prob_buffer.ptr + block_idx * num_dims_per_threadgroup;
|
||||
assert(model->vocabulary_size > num_dims_per_threadgroup * block_idx);
|
||||
uint32_t num_dims_per_block = math_min(num_dims_per_threadgroup, model->vocabulary_size - num_dims_per_threadgroup * block_idx);
|
||||
for (; token_idx < num_dims_per_block; token_idx++) {
|
||||
const float new_cumsum = cumsum + prob_ptr[token_idx];
|
||||
if (new_cumsum >= sample_cdf) {
|
||||
break;
|
||||
}
|
||||
cumsum = new_cumsum;
|
||||
}
|
||||
if (token_idx == num_dims_per_block) {
|
||||
token_idx -= 1;
|
||||
}
|
||||
|
||||
token_idx += block_idx * num_dims_per_threadgroup;
|
||||
|
||||
*token_out = token_idx;
|
||||
|
||||
cleanup:
|
||||
gptoss_metal_command_buffer_release(&command_buffer);
|
||||
return status;
|
||||
}
|
||||
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_context_reset(
|
||||
gptoss_context_t context)
|
||||
{
|
||||
context->num_tokens = 0;
|
||||
context->num_kv_tokens = 0;
|
||||
context->num_batch_tokens = 0;
|
||||
context->num_processed_tokens = 0;
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_context_retain(
|
||||
gptoss_context_t context)
|
||||
{
|
||||
atomic_fetch_add_explicit(&context->ref_count, 1, memory_order_relaxed);
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_context_release(
|
||||
gptoss_context_t context)
|
||||
{
|
||||
if (context != NULL) {
|
||||
if (atomic_fetch_sub_explicit(&context->ref_count, 1, memory_order_acq_rel) == 1) {
|
||||
gptoss_metal_buffer_release(&context->token_buffer);
|
||||
gptoss_metal_buffer_release(&context->score_buffer);
|
||||
gptoss_metal_buffer_release(&context->prob_buffer);
|
||||
gptoss_metal_buffer_release(&context->sum_buffer);
|
||||
gptoss_metal_buffer_release(&context->argmax_buffer);
|
||||
gptoss_metal_buffer_release(&context->kvcache_buffer);
|
||||
|
||||
gptoss_model_release(context->model);
|
||||
|
||||
memset(context, 0, sizeof(struct gptoss_context));
|
||||
free(context);
|
||||
}
|
||||
}
|
||||
return gptoss_status_success;
|
||||
}
|
||||
64
gpt_oss/metal/source/convert.metal
Normal file
64
gpt_oss/metal/source/convert.metal
Normal file
@@ -0,0 +1,64 @@
|
||||
#include <metal_integer>
|
||||
|
||||
#include <internal/kernel-args.h>
|
||||
|
||||
#pragma METAL fp math_mode(safe)
|
||||
#pragma METAL fp contract(off)
|
||||
|
||||
|
||||
kernel void gptoss_mf4_f32_convert(
|
||||
constant gptoss_convert_args& args [[ buffer(0) ]],
|
||||
const device uint4* blocks [[ buffer(1) ]],
|
||||
const device uchar* scales [[ buffer(2) ]],
|
||||
device float4* output [[ buffer(3) ]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_position_in_threadgroup]],
|
||||
uint threadgroup_size [[ threads_per_threadgroup ]])
|
||||
{
|
||||
const ulong num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;
|
||||
const ulong threadgroup_start = gid * num_vecs_per_threadgroup;
|
||||
const ulong threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, args.num_vecs);
|
||||
const ulong thread_start = threadgroup_start + tid;
|
||||
uint num_iter = static_cast<uint>((threadgroup_end - thread_start + (threadgroup_size - 1)) / threadgroup_size);
|
||||
|
||||
blocks += thread_start;
|
||||
scales += thread_start;
|
||||
output += 8 * thread_start;
|
||||
for (; num_iter != 0; num_iter--) {
|
||||
const uint4 block = *blocks;
|
||||
const float scale = as_type<float>((static_cast<uint>(*scales) + 14) << 23);
|
||||
uint4 block02468ACEGIKMOQSU = block + block;
|
||||
uint4 block13579BDFHJLNPRTV = block >> 3;
|
||||
block02468ACEGIKMOQSU &= 0x1E1E1E1Eu;
|
||||
block13579BDFHJLNPRTV &= 0x1E1E1E1Eu;
|
||||
block02468ACEGIKMOQSU += 0x70707070u;
|
||||
block13579BDFHJLNPRTV += 0x70707070u;
|
||||
block02468ACEGIKMOQSU &= 0x8E8E8E8Eu;
|
||||
block13579BDFHJLNPRTV &= 0x8E8E8E8Eu;
|
||||
const uint4 block26AEIMQU = block02468ACEGIKMOQSU & 0xFF00FF00u;
|
||||
const uint4 block048CGKOS = (block02468ACEGIKMOQSU << 8) & 0xFF00FF00u;
|
||||
const uint4 block37BFJNRV = block13579BDFHJLNPRTV & 0xFF00FF00u;
|
||||
const uint4 block159DHLPT = (block13579BDFHJLNPRTV << 8) & 0xFF00FF00u;
|
||||
const float4 block048C = static_cast<float4>(as_type<half4>(block048CGKOS.xy)) * scale;
|
||||
const float4 blockGKOS = static_cast<float4>(as_type<half4>(block048CGKOS.zw)) * scale;
|
||||
const float4 block26AE = static_cast<float4>(as_type<half4>(block26AEIMQU.xy)) * scale;
|
||||
const float4 blockIMQU = static_cast<float4>(as_type<half4>(block26AEIMQU.zw)) * scale;
|
||||
const float4 block159D = static_cast<float4>(as_type<half4>(block159DHLPT.xy)) * scale;
|
||||
const float4 blockHLPT = static_cast<float4>(as_type<half4>(block159DHLPT.zw)) * scale;
|
||||
const float4 block37BF = static_cast<float4>(as_type<half4>(block37BFJNRV.xy)) * scale;
|
||||
const float4 blockJNRV = static_cast<float4>(as_type<half4>(block37BFJNRV.zw)) * scale;
|
||||
|
||||
output[0] = (float4) { block048C.x, block159D.x, block26AE.x, block37BF.x };
|
||||
output[1] = (float4) { block048C.y, block159D.y, block26AE.y, block37BF.y };
|
||||
output[2] = (float4) { block048C.z, block159D.z, block26AE.z, block37BF.z };
|
||||
output[3] = (float4) { block048C.w, block159D.w, block26AE.w, block37BF.w };
|
||||
output[4] = (float4) { blockGKOS.x, blockHLPT.x, blockIMQU.x, blockJNRV.x };
|
||||
output[5] = (float4) { blockGKOS.y, blockHLPT.y, blockIMQU.y, blockJNRV.y };
|
||||
output[6] = (float4) { blockGKOS.z, blockHLPT.z, blockIMQU.z, blockJNRV.z };
|
||||
output[7] = (float4) { blockGKOS.w, blockHLPT.w, blockIMQU.w, blockJNRV.w };
|
||||
|
||||
blocks += threadgroup_size;
|
||||
scales += threadgroup_size;
|
||||
output += 8 * threadgroup_size;
|
||||
}
|
||||
}
|
||||
24
gpt_oss/metal/source/embeddings.metal
Normal file
24
gpt_oss/metal/source/embeddings.metal
Normal file
@@ -0,0 +1,24 @@
|
||||
#include <internal/kernel-args.h>
|
||||
|
||||
#pragma METAL fp math_mode(safe)
|
||||
#pragma METAL fp contract(off)
|
||||
|
||||
|
||||
kernel void gptoss_bf16_f32_embeddings(
|
||||
constant gptoss_embeddings_args& args [[ buffer(0) ]],
|
||||
const device uint* tokens [[ buffer(1) ]],
|
||||
const device bfloat4* weights [[ buffer(2) ]],
|
||||
device float4* output [[ buffer(3) ]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_position_in_threadgroup]],
|
||||
uint threadgroup_size [[ threads_per_threadgroup ]])
|
||||
{
|
||||
const uint t = tokens[gid];
|
||||
|
||||
weights += t * args.num_vecs;
|
||||
output += gid * args.num_vecs;
|
||||
for (uint i = tid; i < args.num_vecs; i += threadgroup_size) {
|
||||
const bfloat4 w = weights[i];
|
||||
output[i] = static_cast<float4>(w);
|
||||
}
|
||||
}
|
||||
316
gpt_oss/metal/source/generate.c
Normal file
316
gpt_oss/metal/source/generate.c
Normal file
@@ -0,0 +1,316 @@
|
||||
#include <assert.h>
|
||||
#include <inttypes.h>
|
||||
#include <math.h>
|
||||
#include <signal.h>
|
||||
#include <stdatomic.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <mach/mach_time.h>
|
||||
|
||||
#include <gpt-oss.h>
|
||||
|
||||
#include "internal/model.h"
|
||||
|
||||
struct {
|
||||
atomic_uint_least64_t inference_bytes;
|
||||
atomic_size_t num_prefill_tokens;
|
||||
atomic_uint_least64_t prefill_microseconds;
|
||||
atomic_size_t num_generated_tokens;
|
||||
atomic_uint_least64_t generation_microseconds;
|
||||
} globals = {
|
||||
.inference_bytes = 0,
|
||||
.num_prefill_tokens = 0,
|
||||
.prefill_microseconds = 0,
|
||||
.num_generated_tokens = 0,
|
||||
.generation_microseconds = 0,
|
||||
};
|
||||
|
||||
struct options {
|
||||
const char* model;
|
||||
const char* prompt;
|
||||
size_t context_length;
|
||||
size_t max_tokens;
|
||||
float temperature;
|
||||
bool verbose;
|
||||
};
|
||||
|
||||
static inline double mach_timestamp_diff_to_seconds(uint64_t start_timestamp, uint64_t end_timestamp) {
|
||||
static mach_timebase_info_data_t timebase_info = {0};
|
||||
if (timebase_info.denom == 0) {
|
||||
mach_timebase_info(&timebase_info);
|
||||
}
|
||||
const uint64_t elapsed_mach_time = end_timestamp - start_timestamp;
|
||||
return ((double) elapsed_mach_time * (double) timebase_info.numer) / ((double) timebase_info.denom * 1.0e+9);
|
||||
}
|
||||
|
||||
static inline uint64_t mach_timestamp_diff_to_microseconds(uint64_t start_timestamp, uint64_t end_timestamp) {
|
||||
static mach_timebase_info_data_t timebase_info = {0};
|
||||
if (timebase_info.denom == 0) {
|
||||
mach_timebase_info(&timebase_info);
|
||||
}
|
||||
const uint64_t elapsed_mach_time = end_timestamp - start_timestamp;
|
||||
const uint64_t denominator = timebase_info.denom * UINT64_C(1000);
|
||||
return (elapsed_mach_time * timebase_info.numer + denominator / 2) / denominator;
|
||||
}
|
||||
|
||||
static void print_usage(const char* program_name) {
|
||||
printf("Usage: %s <model-path> [-p <prompt>] [-n <tokens>]\n", program_name);
|
||||
}
|
||||
|
||||
struct options parse_options(int argc, char** argv) {
|
||||
struct options options = (struct options) {
|
||||
.model = NULL,
|
||||
.prompt = NULL,
|
||||
.context_length = 0,
|
||||
.max_tokens = 0,
|
||||
.temperature = 0.0f,
|
||||
.verbose = false,
|
||||
};
|
||||
if (argc < 2) {
|
||||
fprintf(stderr, "Error: missing required command-line argument\n");
|
||||
print_usage(argv[0]);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
for (int i = 1; i < argc; i++) {
|
||||
if (strcmp(argv[i], "--help") == 0) {
|
||||
print_usage(argv[0]);
|
||||
exit(EXIT_SUCCESS);
|
||||
} else if (strcmp(argv[i], "-p") == 0 || strcmp(argv[i], "--prompt") == 0) {
|
||||
if (i + 1 >= argc) {
|
||||
fprintf(stderr, "Error: missing argument for %s\n", argv[i]);
|
||||
print_usage(argv[0]);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
options.prompt = argv[++i];
|
||||
} else if (strcmp(argv[i], "--context-length") == 0) {
|
||||
if (i + 1 >= argc) {
|
||||
fprintf(stderr, "Error: missing argument for --context-length\n");
|
||||
print_usage(argv[0]);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
char* context_length_start = argv[++i];
|
||||
char* context_length_end = context_length_start;
|
||||
options.context_length = strtoul(context_length_start, &context_length_end, 10);
|
||||
if (context_length_end == context_length_start || *context_length_end != 0) {
|
||||
fprintf(stderr, "Error: failed to parse context length value \"%s\"\n", context_length_start);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
} else if (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--max-tokens") == 0) {
|
||||
if (i + 1 >= argc) {
|
||||
fprintf(stderr, "Error: missing argument for %s\n", argv[i]);
|
||||
print_usage(argv[0]);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
char* max_tokens_start = argv[++i];
|
||||
char* max_tokens_end = max_tokens_start;
|
||||
options.max_tokens = strtoul(max_tokens_start, &max_tokens_end, 10);
|
||||
if (max_tokens_end == max_tokens_start || *max_tokens_end != 0) {
|
||||
fprintf(stderr, "Error: failed to max tokens value \"%s\"\n", max_tokens_start);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
if (options.max_tokens == 0) {
|
||||
fprintf(stderr, "Error: invalid max tokens value %zu\n", options.max_tokens);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
} else if (strcmp(argv[i], "-t") == 0 || strcmp(argv[i], "--temperature") == 0) {
|
||||
if (i + 1 >= argc) {
|
||||
fprintf(stderr, "Error: missing argument for %s\n", argv[i]);
|
||||
print_usage(argv[0]);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
char* temperature_start = argv[++i];
|
||||
char* temperature_end = temperature_start;
|
||||
options.temperature = strtof(temperature_start, &temperature_end);
|
||||
if (temperature_end == temperature_start || *temperature_end != 0) {
|
||||
fprintf(stderr, "Error: failed to parse temperature value \"%s\"\n", temperature_start);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
if (signbit(options.temperature) != 0 || !(options.temperature <= 2.0f)) {
|
||||
fprintf(stderr, "Error: invalid temperature value %f\n", options.temperature);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
} else if (strcmp(argv[i], "-v") == 0 || strcmp(argv[i], "--verbose") == 0) {
|
||||
options.verbose = true;
|
||||
} else {
|
||||
if (options.model == NULL) {
|
||||
options.model = argv[i];
|
||||
} else {
|
||||
fprintf(stderr, "Error: unexpected command-line argument %s\n", argv[i]);
|
||||
print_usage(argv[0]);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (options.model == NULL) {
|
||||
fprintf(stderr, "Error: missing required model argument\n");
|
||||
print_usage(argv[0]);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
if (options.prompt == NULL) {
|
||||
fprintf(stderr, "Error: missing required prompt argument\n");
|
||||
print_usage(argv[0]);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
return options;
|
||||
}
|
||||
|
||||
|
||||
static void print_profile() {
|
||||
const size_t num_prefill_tokens = atomic_load(&globals.num_prefill_tokens);
|
||||
const uint64_t prefill_microseconds = atomic_load(&globals.prefill_microseconds);
|
||||
const size_t num_generated_tokens = atomic_load(&globals.num_generated_tokens) - 1;
|
||||
const uint64_t generation_microseconds = atomic_load(&globals.generation_microseconds);
|
||||
const uint64_t inference_bytes = atomic_load(&globals.inference_bytes);
|
||||
if (num_prefill_tokens != 0 || num_generated_tokens != 0) {
|
||||
printf("\n");
|
||||
}
|
||||
if (num_prefill_tokens != 0) {
|
||||
printf("Prefill speed (%zu tokens): %.1f tokens/second\n",
|
||||
num_prefill_tokens,
|
||||
(double) num_prefill_tokens / (double) prefill_microseconds * 1.0e+6);
|
||||
}
|
||||
if (num_generated_tokens > 5) {
|
||||
printf("Generation speed (%zu tokens, excluding the first 5): %.1f tokens/second\n",
|
||||
(num_generated_tokens - 5),
|
||||
(double) (num_generated_tokens - 5) / (double) generation_microseconds * 1.0e+6);
|
||||
}
|
||||
}
|
||||
|
||||
static void ctrl_c_handler(int signum) {
|
||||
print_profile();
|
||||
exit(EXIT_SUCCESS);
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
enum gptoss_status status;
|
||||
gptoss_model_t model = NULL;
|
||||
gptoss_tokenizer_t tokenizer = NULL;
|
||||
gptoss_context_t context = NULL;
|
||||
|
||||
struct sigaction act;
|
||||
act.sa_handler = ctrl_c_handler;
|
||||
sigaction(SIGINT, &act, NULL);
|
||||
|
||||
setvbuf(stdout, NULL, _IONBF, 0);
|
||||
|
||||
struct options options = parse_options(argc, argv);
|
||||
|
||||
const uint64_t load_start_time = mach_continuous_time();
|
||||
status = gptoss_model_create_from_file(options.model, &model);
|
||||
if (status != gptoss_status_success) {
|
||||
fprintf(stderr, "Error: failed to load model from file %s\n", options.model);
|
||||
goto error;
|
||||
}
|
||||
size_t max_model_context_length = 0;
|
||||
status = gptoss_model_get_max_context_length(model, &max_model_context_length);
|
||||
if (status != gptoss_status_success) {
|
||||
fprintf(stderr, "Error: failed to query maximum context length\n");
|
||||
goto error;
|
||||
}
|
||||
assert(max_model_context_length != 0);
|
||||
if (options.context_length == 0) {
|
||||
options.context_length = max_model_context_length;
|
||||
} else if (options.context_length > max_model_context_length) {
|
||||
fprintf(stderr, "Error: context length %zu exceeds maximum context length %zu supported by the model\n", options.context_length, max_model_context_length);
|
||||
goto error;
|
||||
}
|
||||
|
||||
status = gptoss_model_get_tokenizer(model, &tokenizer);
|
||||
if (status != gptoss_status_success) {
|
||||
fprintf(stderr, "Error: failed to retrieve Tokenizer\n");
|
||||
goto error;
|
||||
}
|
||||
|
||||
uint32_t return_token_id = UINT32_MAX;
|
||||
status = gptoss_tokenizer_get_special_token_id(tokenizer, gptoss_special_token_return, &return_token_id);
|
||||
if (status != gptoss_status_success) {
|
||||
fprintf(stderr, "Error: failed to query end-of-text token ID\n");
|
||||
goto error;
|
||||
}
|
||||
|
||||
status = gptoss_context_create(model, options.context_length, &context);
|
||||
if (status != gptoss_status_success) {
|
||||
fprintf(stderr, "Error: failed to create Context object\n");
|
||||
goto error;
|
||||
}
|
||||
if (options.verbose) {
|
||||
printf("Model weights size: %.2lf MB\n", (double) model->weights_size * 0x1.0p-20);
|
||||
printf("Model allocation size: %.2lf MB\n", (double) model->allocation_size * 0x1.0p-20);
|
||||
printf("Context allocation size: %.2lf MB\n", (double) context->allocation_size * 0x1.0p-20);
|
||||
printf(" Including KV cache: %.2lf MB\n", (double) context->kvcache_size * 0x1.0p-20);
|
||||
}
|
||||
|
||||
const uint64_t load_end_time = mach_continuous_time();
|
||||
const double load_elapsed_seconds = mach_timestamp_diff_to_seconds(load_start_time, load_end_time);
|
||||
if (options.verbose) {
|
||||
printf("Loaded model in %.3f seconds\n", load_elapsed_seconds);
|
||||
}
|
||||
|
||||
const uint64_t prefill_start_time = mach_continuous_time();
|
||||
size_t num_prefill_tokens = 0;
|
||||
status = gptoss_context_append_chars(context, options.prompt, strlen(options.prompt), &num_prefill_tokens);
|
||||
if (status != gptoss_status_success) {
|
||||
fprintf(stderr, "Error: failed to tokenize prompt \"%s\"\n", options.prompt);
|
||||
goto error;
|
||||
}
|
||||
atomic_store(&globals.num_prefill_tokens, num_prefill_tokens);
|
||||
status = gptoss_context_process(context);
|
||||
if (status != gptoss_status_success) {
|
||||
fprintf(stderr, "Error: failed to process Context object\n");
|
||||
goto error;
|
||||
}
|
||||
const uint64_t prefill_end_time = mach_continuous_time();
|
||||
|
||||
while (options.max_tokens == 0 || atomic_load(&globals.num_generated_tokens) < options.max_tokens) {
|
||||
|
||||
uint32_t predicted_token = UINT32_MAX;
|
||||
const uint64_t inference_start_timestamp = mach_continuous_time();
|
||||
status = gptoss_context_sample(context, options.temperature, /*rng_state=*/0, &predicted_token);
|
||||
if (status != gptoss_status_success) {
|
||||
fprintf(stderr, "Error: failed to sample from the Context object\n");
|
||||
goto error;
|
||||
}
|
||||
const uint64_t inference_end_timestamp = mach_continuous_time();
|
||||
|
||||
if (predicted_token == return_token_id) {
|
||||
// Yield token -> stop generation
|
||||
break;
|
||||
}
|
||||
|
||||
// Unembedding: detokenize
|
||||
size_t token_size = 0;
|
||||
const void* token_ptr = NULL;
|
||||
status = gptoss_tokenizer_decode(tokenizer, predicted_token, &token_ptr, &token_size);
|
||||
if (status != gptoss_status_success) {
|
||||
fprintf(stderr, "Error: failed to detokenize predicted token %" PRIu32 "\n", predicted_token);
|
||||
goto error;
|
||||
}
|
||||
const size_t previous_num_generated_tokens = atomic_fetch_add(&globals.num_generated_tokens, 1);
|
||||
if (previous_num_generated_tokens == 0) {
|
||||
atomic_fetch_add(&globals.prefill_microseconds, mach_timestamp_diff_to_microseconds(prefill_start_time, prefill_end_time));
|
||||
} else if (previous_num_generated_tokens > 5) {
|
||||
atomic_fetch_add(&globals.generation_microseconds, mach_timestamp_diff_to_microseconds(inference_start_timestamp, inference_end_timestamp));
|
||||
}
|
||||
printf("%.*s", (int) token_size, (const char*) token_ptr);
|
||||
|
||||
status = gptoss_context_append_tokens(context, 1, &predicted_token);
|
||||
if (status != gptoss_status_success) {
|
||||
fprintf(stderr, "Error: failed to append predicted token %" PRIu32 " to context\n", predicted_token);
|
||||
goto error;
|
||||
}
|
||||
}
|
||||
|
||||
print_profile();
|
||||
|
||||
return EXIT_SUCCESS;
|
||||
|
||||
error:
|
||||
gptoss_context_release(context);
|
||||
gptoss_tokenizer_release(tokenizer);
|
||||
gptoss_model_release(model);
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
41
gpt_oss/metal/source/include/internal/datatype.h
Normal file
41
gpt_oss/metal/source/include/internal/datatype.h
Normal file
@@ -0,0 +1,41 @@
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include <internal/macros.h>
|
||||
|
||||
|
||||
typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
|
||||
GPTOSS_ALIGN(2) uint16_t bits;
|
||||
} gptoss_bfloat16;
|
||||
static_assert(sizeof(gptoss_bfloat16) == 2, "bfloat16 size is not 2 bytes");
|
||||
|
||||
|
||||
typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
|
||||
GPTOSS_ALIGN(2) uint16_t bits;
|
||||
} gptoss_float16;
|
||||
static_assert(sizeof(gptoss_float16) == 2, "float16 size is not 2 bytes");
|
||||
|
||||
|
||||
typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
|
||||
GPTOSS_ALIGN(1) uint8_t bits;
|
||||
} gptoss_float8ue8m0;
|
||||
static_assert(sizeof(gptoss_float8ue8m0) == 1, "gptoss_float8ue8m0 size is not 1 bytes");
|
||||
|
||||
|
||||
typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
|
||||
GPTOSS_ALIGN(1) uint8_t bits;
|
||||
} gptoss_float8e5m2;
|
||||
static_assert(sizeof(gptoss_float8e5m2) == 1, "float8e5m2 size is not 1 bytes");
|
||||
|
||||
|
||||
typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
|
||||
GPTOSS_ALIGN(1) uint8_t bits;
|
||||
} gptoss_float8e4m3;
|
||||
static_assert(sizeof(gptoss_float8e4m3) == 1, "gptoss_float8e4m3 size is not 1 bytes");
|
||||
|
||||
|
||||
typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
|
||||
GPTOSS_ALIGN(1) uint8_t bits;
|
||||
} gptoss_float4e2m1x2;
|
||||
static_assert(sizeof(gptoss_float4e2m1x2) == 1, "gptoss_float4e2m1x2 size is not 1 bytes");
|
||||
87
gpt_oss/metal/source/include/internal/datatype.hpp
Normal file
87
gpt_oss/metal/source/include/internal/datatype.hpp
Normal file
@@ -0,0 +1,87 @@
|
||||
#pragma once
|
||||
|
||||
#include <bit>
|
||||
|
||||
#include <internal/datatype.h>
|
||||
|
||||
|
||||
namespace gptoss {
|
||||
|
||||
template <typename WideT, typename NarrowT>
|
||||
WideT upcast(NarrowT);
|
||||
|
||||
template <>
|
||||
inline float upcast<float>(gptoss_bfloat16 bf16_value) {
|
||||
const uint32_t bits = static_cast<uint32_t>(bf16_value.bits) << 16;
|
||||
return std::bit_cast<float>(bits);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline float upcast<float>(gptoss_float16 fp16_value) {
|
||||
return static_cast<float>(std::bit_cast<_Float16>(fp16_value.bits));
|
||||
}
|
||||
|
||||
template <>
|
||||
inline float upcast<float>(gptoss_float8e4m3 fp8_value) {
|
||||
static constexpr uint16_t fp8e4m3_to_fp32[256] = {
|
||||
0x0000, 0x3B00, 0x3B80, 0x3BC0, 0x3C00, 0x3C20, 0x3C40, 0x3C60,
|
||||
0x3C80, 0x3C90, 0x3CA0, 0x3CB0, 0x3CC0, 0x3CD0, 0x3CE0, 0x3CF0,
|
||||
0x3D00, 0x3D10, 0x3D20, 0x3D30, 0x3D40, 0x3D50, 0x3D60, 0x3D70,
|
||||
0x3D80, 0x3D90, 0x3DA0, 0x3DB0, 0x3DC0, 0x3DD0, 0x3DE0, 0x3DF0,
|
||||
0x3E00, 0x3E10, 0x3E20, 0x3E30, 0x3E40, 0x3E50, 0x3E60, 0x3E70,
|
||||
0x3E80, 0x3E90, 0x3EA0, 0x3EB0, 0x3EC0, 0x3ED0, 0x3EE0, 0x3EF0,
|
||||
0x3F00, 0x3F10, 0x3F20, 0x3F30, 0x3F40, 0x3F50, 0x3F60, 0x3F70,
|
||||
0x3F80, 0x3F90, 0x3FA0, 0x3FB0, 0x3FC0, 0x3FD0, 0x3FE0, 0x3FF0,
|
||||
0x4000, 0x4010, 0x4020, 0x4030, 0x4040, 0x4050, 0x4060, 0x4070,
|
||||
0x4080, 0x4090, 0x40A0, 0x40B0, 0x40C0, 0x40D0, 0x40E0, 0x40F0,
|
||||
0x4100, 0x4110, 0x4120, 0x4130, 0x4140, 0x4150, 0x4160, 0x4170,
|
||||
0x4180, 0x4190, 0x41A0, 0x41B0, 0x41C0, 0x41D0, 0x41E0, 0x41F0,
|
||||
0x4200, 0x4210, 0x4220, 0x4230, 0x4240, 0x4250, 0x4260, 0x4270,
|
||||
0x4280, 0x4290, 0x42A0, 0x42B0, 0x42C0, 0x42D0, 0x42E0, 0x42F0,
|
||||
0x4300, 0x4310, 0x4320, 0x4330, 0x4340, 0x4350, 0x4360, 0x4370,
|
||||
0x4380, 0x4390, 0x43A0, 0x43B0, 0x43C0, 0x43D0, 0x43E0, 0x7FF0,
|
||||
0x8000, 0xBB00, 0xBB80, 0xBBC0, 0xBC00, 0xBC20, 0xBC40, 0xBC60,
|
||||
0xBC80, 0xBC90, 0xBCA0, 0xBCB0, 0xBCC0, 0xBCD0, 0xBCE0, 0xBCF0,
|
||||
0xBD00, 0xBD10, 0xBD20, 0xBD30, 0xBD40, 0xBD50, 0xBD60, 0xBD70,
|
||||
0xBD80, 0xBD90, 0xBDA0, 0xBDB0, 0xBDC0, 0xBDD0, 0xBDE0, 0xBDF0,
|
||||
0xBE00, 0xBE10, 0xBE20, 0xBE30, 0xBE40, 0xBE50, 0xBE60, 0xBE70,
|
||||
0xBE80, 0xBE90, 0xBEA0, 0xBEB0, 0xBEC0, 0xBED0, 0xBEE0, 0xBEF0,
|
||||
0xBF00, 0xBF10, 0xBF20, 0xBF30, 0xBF40, 0xBF50, 0xBF60, 0xBF70,
|
||||
0xBF80, 0xBF90, 0xBFA0, 0xBFB0, 0xBFC0, 0xBFD0, 0xBFE0, 0xBFF0,
|
||||
0xC000, 0xC010, 0xC020, 0xC030, 0xC040, 0xC050, 0xC060, 0xC070,
|
||||
0xC080, 0xC090, 0xC0A0, 0xC0B0, 0xC0C0, 0xC0D0, 0xC0E0, 0xC0F0,
|
||||
0xC100, 0xC110, 0xC120, 0xC130, 0xC140, 0xC150, 0xC160, 0xC170,
|
||||
0xC180, 0xC190, 0xC1A0, 0xC1B0, 0xC1C0, 0xC1D0, 0xC1E0, 0xC1F0,
|
||||
0xC200, 0xC210, 0xC220, 0xC230, 0xC240, 0xC250, 0xC260, 0xC270,
|
||||
0xC280, 0xC290, 0xC2A0, 0xC2B0, 0xC2C0, 0xC2D0, 0xC2E0, 0xC2F0,
|
||||
0xC300, 0xC310, 0xC320, 0xC330, 0xC340, 0xC350, 0xC360, 0xC370,
|
||||
0xC380, 0xC390, 0xC3A0, 0xC3B0, 0xC3C0, 0xC3D0, 0xC3E0, 0xFFF0,
|
||||
};
|
||||
const gptoss_bfloat16 bf16_value{.bits = fp8e4m3_to_fp32[fp8_value.bits]};
|
||||
return upcast<float>(bf16_value);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline double upcast<double>(float fp32_value) {
|
||||
return static_cast<double>(fp32_value);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline double upcast<double>(gptoss_bfloat16 bf16_value) {
|
||||
const float fp32_value = upcast<float>(bf16_value);
|
||||
return upcast<double>(fp32_value);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline double upcast<double>(gptoss_float16 fp16_value) {
|
||||
const float fp32_value = upcast<float>(fp16_value);
|
||||
return upcast<double>(fp32_value);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline double upcast<double>(gptoss_float8e4m3 fp8_value) {
|
||||
const float fp32_value = upcast<float>(fp8_value);
|
||||
return upcast<double>(fp32_value);
|
||||
}
|
||||
|
||||
} // namespace gptoss
|
||||
105
gpt_oss/metal/source/include/internal/kernel-args.h
Normal file
105
gpt_oss/metal/source/include/internal/kernel-args.h
Normal file
@@ -0,0 +1,105 @@
|
||||
#pragma once
|
||||
|
||||
#if !defined(__METAL_VERSION__)
|
||||
#include <stdint.h>
|
||||
#endif
|
||||
|
||||
struct gptoss_expert_prediction {
|
||||
uint32_t expert_id;
|
||||
float score;
|
||||
};
|
||||
|
||||
struct gptoss_topk_args {
|
||||
uint32_t num_vecs_per_token;
|
||||
};
|
||||
|
||||
struct gptoss_sdpa_args {
|
||||
uint32_t qkv_dim;
|
||||
uint32_t num_kv_tokens;
|
||||
uint32_t window;
|
||||
};
|
||||
|
||||
struct gptoss_u32_fill_random_args {
|
||||
uint64_t num_vecs_per_threadgroup;
|
||||
uint64_t num_vecs;
|
||||
uint64_t offset;
|
||||
uint64_t seed;
|
||||
};
|
||||
|
||||
struct gptoss_f32_fill_random_args {
|
||||
uint64_t num_vecs_per_threadgroup;
|
||||
uint64_t num_vecs;
|
||||
uint64_t offset;
|
||||
uint64_t seed;
|
||||
float scale;
|
||||
float bias;
|
||||
};
|
||||
|
||||
struct gptoss_accumulate_args {
|
||||
uint32_t num_vecs_per_expert;
|
||||
uint32_t num_vecs_per_threadgroup;
|
||||
uint32_t num_vecs;
|
||||
};
|
||||
|
||||
struct gptoss_convert_args {
|
||||
uint64_t num_vecs_per_threadgroup;
|
||||
uint64_t num_vecs;
|
||||
};
|
||||
|
||||
struct gptoss_embeddings_args {
|
||||
uint32_t num_vecs;
|
||||
};
|
||||
|
||||
struct gptoss_rmsnorm_args {
|
||||
uint32_t num_vecs;
|
||||
float num_channels;
|
||||
float epsilon;
|
||||
};
|
||||
|
||||
struct gptoss_matmul_args {
|
||||
uint32_t num_column_vecs;
|
||||
uint32_t num_rows;
|
||||
uint32_t add;
|
||||
};
|
||||
|
||||
struct gptoss_unembedding_args {
|
||||
uint32_t num_column_vecs;
|
||||
uint32_t num_rows_per_threadgroup;
|
||||
uint32_t num_rows;
|
||||
};
|
||||
|
||||
struct gptoss_moe_matmul_swiglu_args {
|
||||
uint32_t num_column_vecs;
|
||||
uint32_t num_rows;
|
||||
uint32_t num_active_experts;
|
||||
uint32_t weight_expert_stride; // in bytes
|
||||
uint32_t output_expert_stride; // in elements
|
||||
float swiglu_min;
|
||||
float swiglu_max;
|
||||
};
|
||||
|
||||
struct gptoss_moe_matmul_args {
|
||||
uint32_t num_column_vecs;
|
||||
uint32_t num_rows;
|
||||
uint32_t num_active_experts;
|
||||
uint32_t input_expert_stride; // in blocks of 32 elements
|
||||
uint32_t weight_expert_stride; // in bytes
|
||||
uint32_t output_expert_stride; // in elements
|
||||
};
|
||||
|
||||
struct gptoss_rope_args {
|
||||
uint32_t token_stride;
|
||||
uint32_t token_offset;
|
||||
float freq_scale;
|
||||
float interpolation_scale;
|
||||
float yarn_offset;
|
||||
float yarn_scale;
|
||||
float yarn_multiplier;
|
||||
};
|
||||
|
||||
struct gptoss_softmax_args {
|
||||
uint32_t num_vecs;
|
||||
uint32_t num_vecs_per_threadgroup;
|
||||
uint32_t max_threadgroups;
|
||||
float temperature;
|
||||
};
|
||||
20
gpt_oss/metal/source/include/internal/log.h
Normal file
20
gpt_oss/metal/source/include/internal/log.h
Normal file
@@ -0,0 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include <stdarg.h>
|
||||
|
||||
|
||||
void gptoss_format_log(const char* format, va_list args);
|
||||
|
||||
__attribute__((__format__(__printf__, 1, 2)))
|
||||
inline static void gptoss_log(const char* format, ...) {
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
gptoss_format_log(format, args);
|
||||
va_end(args);
|
||||
}
|
||||
|
||||
#define GPTOSS_LOG_ERROR(message, ...) \
|
||||
gptoss_log("Error: " message "\n", ##__VA_ARGS__)
|
||||
|
||||
#define GPTOSS_LOG_WARNING(message, ...) \
|
||||
gptoss_log("Warning: " message "\n", ##__VA_ARGS__)
|
||||
107
gpt_oss/metal/source/include/internal/macros.h
Normal file
107
gpt_oss/metal/source/include/internal/macros.h
Normal file
@@ -0,0 +1,107 @@
|
||||
#pragma once
|
||||
|
||||
/***** Architecture detection macros *****/
|
||||
|
||||
#ifdef GPTOSS_ARCH_X86_64
|
||||
#if GPTOSS_ARCH_X86_64 != 0 && GPTOSS_ARCH_X86_64 != 1
|
||||
#error "Invalid GPTOSS_ARCH_X86_64 value: must be either 0 or 1"
|
||||
#endif
|
||||
#else
|
||||
#if defined(__x86_64__) || defined(_M_X64) && !defined(_M_ARM64EC)
|
||||
#define GPTOSS_ARCH_X86_64 1
|
||||
#else
|
||||
#define GPTOSS_ARCH_X86_64 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef GPTOSS_ARCH_ARM64
|
||||
#if GPTOSS_ARCH_ARM64 != 0 && GPTOSS_ARCH_ARM64 != 1
|
||||
#error "Invalid GPTOSS_ARCH_ARM64 value: must be either 0 or 1"
|
||||
#endif
|
||||
#else
|
||||
#if defined(__aarch64__) || defined(_M_ARM64) || defined(_M_ARM64EC)
|
||||
#define GPTOSS_ARCH_ARM64 1
|
||||
#else
|
||||
#define GPTOSS_ARCH_ARM64 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if GPTOSS_ARCH_X86_64 + GPTOSS_ARCH_ARM64 == 0
|
||||
#error "Unsupported architecture: neither x86-64 nor ARM64 detected"
|
||||
#elif GPTOSS_ARCH_X86_64 + GPTOSS_ARCH_ARM64 != 1
|
||||
#error "Inconsistent architecture detection: both x86-64 and ARM64 detection macros are specified"
|
||||
#endif
|
||||
|
||||
/***** Compiler portability macros *****/
|
||||
|
||||
#ifndef GPTOSS_LIKELY
|
||||
#if defined(__GNUC__)
|
||||
#define GPTOSS_LIKELY(condition) (__builtin_expect(!!(condition), 1))
|
||||
#else
|
||||
#define GPTOSS_LIKELY(condition) (!!(condition))
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef GPTOSS_UNLIKELY
|
||||
#if defined(__GNUC__)
|
||||
#define GPTOSS_UNLIKELY(condition) (__builtin_expect(!!(condition), 0))
|
||||
#else
|
||||
#define GPTOSS_UNLIKELY(condition) (!!(condition))
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef GPTOSS_UNPREDICTABLE
|
||||
#if defined(__has_builtin)
|
||||
#if __has_builtin(__builtin_unpredictable)
|
||||
#define GPTOSS_UNPREDICTABLE(condition) (__builtin_unpredictable(!!(condition)))
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
#ifndef GPTOSS_UNPREDICTABLE
|
||||
#if defined(__GNUC__) && (__GNUC__ >= 9) && !defined(__INTEL_COMPILER)
|
||||
#define GPTOSS_UNPREDICTABLE(condition) (__builtin_expect_with_probability(!!(condition), 0, 0.5))
|
||||
#else
|
||||
#define GPTOSS_UNPREDICTABLE(condition) (!!(condition))
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Disable padding for structure members.
|
||||
#ifndef GPTOSS_DENSELY_PACKED_STRUCTURE
|
||||
#if defined(__GNUC__)
|
||||
#define GPTOSS_DENSELY_PACKED_STRUCTURE __attribute__((__packed__))
|
||||
#else
|
||||
#error "Compiler-specific implementation of GPTOSS_DENSELY_PACKED_STRUCTURE required"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef GPTOSS_ALIGN
|
||||
#if defined(__GNUC__)
|
||||
#define GPTOSS_ALIGN(alignment) __attribute__((__aligned__(alignment)))
|
||||
#elif defined(_MSC_VER)
|
||||
#define GPTOSS_ALIGN(alignment) __declspec(align(alignment))
|
||||
#else
|
||||
#error "Compiler-specific implementation of GPTOSS_ALIGN required"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef GPTOSS_FORCE_INLINE
|
||||
#if defined(__GNUC__)
|
||||
#define GPTOSS_FORCE_INLINE inline __attribute__((__always_inline__))
|
||||
#elif defined(_MSC_VER)
|
||||
#define GPTOSS_FORCE_INLINE __forceinline
|
||||
#else
|
||||
#define GPTOSS_FORCE_INLINE inline
|
||||
#endif
|
||||
#endif
|
||||
|
||||
/***** Symbol visibility macros *****/
|
||||
|
||||
#ifndef GPTOSS_INTERNAL_SYMBOL
|
||||
#if defined(__ELF__)
|
||||
#define GPTOSS_INTERNAL_SYMBOL __attribute__((__visibility__("internal")))
|
||||
#elif defined(__MACH__)
|
||||
#define GPTOSS_INTERNAL_SYMBOL __attribute__((__visibility__("hidden")))
|
||||
#else
|
||||
#define GPTOSS_INTERNAL_SYMBOL
|
||||
#endif
|
||||
#endif
|
||||
25
gpt_oss/metal/source/include/internal/math.h
Normal file
25
gpt_oss/metal/source/include/internal/math.h
Normal file
@@ -0,0 +1,25 @@
|
||||
#pragma once
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
inline static size_t math_ceil_div(size_t numer, size_t denom) {
|
||||
return (numer + denom - 1) / denom;
|
||||
}
|
||||
|
||||
inline static size_t math_max(size_t a, size_t b) {
|
||||
return a >= b ? a : b;
|
||||
}
|
||||
|
||||
inline static size_t math_min(size_t a, size_t b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
|
||||
static size_t math_round_up_po2(size_t bytes, size_t multiple) {
|
||||
const size_t multiple_mask = multiple - 1;
|
||||
if ((bytes & multiple_mask) != 0) {
|
||||
bytes |= multiple_mask;
|
||||
bytes += 1;
|
||||
}
|
||||
return bytes;
|
||||
}
|
||||
270
gpt_oss/metal/source/include/internal/metal-kernels.h
Normal file
270
gpt_oss/metal/source/include/internal/metal-kernels.h
Normal file
@@ -0,0 +1,270 @@
|
||||
#pragma once
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <internal/metal.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <internal/kernel-args.h>
|
||||
#include <internal/math.h>
|
||||
#include <internal/metal.h>
|
||||
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_u32_fill_random(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* u32_fill_random_fn,
|
||||
size_t threadgroup_size,
|
||||
size_t max_threadgroups,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
uint64_t num_elements,
|
||||
uint64_t rng_seed,
|
||||
uint64_t rng_offset);
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_fill_random(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_fill_random_fn,
|
||||
size_t threadgroup_size,
|
||||
size_t max_threadgroups,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
uint64_t num_elements,
|
||||
uint64_t rng_seed,
|
||||
uint64_t rng_offset,
|
||||
float rng_min,
|
||||
float rng_max);
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_fill_random(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* bf16_fill_random_fn,
|
||||
size_t threadgroup_size,
|
||||
size_t max_threadgroups,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
uint64_t num_elements,
|
||||
uint64_t rng_seed,
|
||||
uint64_t rng_offset,
|
||||
float rng_min,
|
||||
float rng_max);
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_mf4_f32_convert(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* mf4_f32_convert_fn,
|
||||
size_t threadgroup_size,
|
||||
size_t max_threadgroups,
|
||||
const struct gptoss_metal_buffer* block_buffer,
|
||||
const struct gptoss_metal_buffer* scale_buffer,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
uint64_t num_elements);
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* bf16_f32_embeddings_fn,
|
||||
size_t threadgroup_size,
|
||||
const struct gptoss_metal_buffer* token_buffer,
|
||||
size_t token_offset,
|
||||
const struct gptoss_metal_buffer* weight_buffer,
|
||||
size_t weight_offset,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
uint32_t num_tokens,
|
||||
uint32_t num_channels);
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_bf16w_rmsnorm_fn,
|
||||
const struct gptoss_metal_buffer* input_buffer,
|
||||
size_t input_offset,
|
||||
const struct gptoss_metal_buffer* weight_buffer,
|
||||
size_t weight_offset,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
uint32_t num_tokens,
|
||||
uint32_t num_channels,
|
||||
float epsilon);
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_bf16w_matmul_fn,
|
||||
size_t threadgroup_size,
|
||||
const struct gptoss_metal_buffer* input_buffer,
|
||||
size_t input_offset,
|
||||
const struct gptoss_metal_buffer* weight_buffer,
|
||||
size_t weight_offset,
|
||||
const struct gptoss_metal_buffer* bias_buffer,
|
||||
size_t bias_offset,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
uint32_t num_tokens,
|
||||
uint32_t num_cols,
|
||||
uint32_t num_rows);
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_bf16w_matmul_fn,
|
||||
size_t threadgroup_size,
|
||||
const struct gptoss_metal_buffer* input_buffer,
|
||||
size_t input_offset,
|
||||
const struct gptoss_metal_buffer* weight_buffer,
|
||||
size_t weight_offset,
|
||||
const struct gptoss_metal_buffer* bias_buffer,
|
||||
size_t bias_offset,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
uint32_t num_tokens,
|
||||
uint32_t num_cols,
|
||||
uint32_t num_rows);
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_bf16w_matmul_fn,
|
||||
size_t threadgroup_size,
|
||||
size_t max_threadgroups,
|
||||
const struct gptoss_metal_buffer* input_buffer,
|
||||
size_t input_offset,
|
||||
const struct gptoss_metal_buffer* weight_buffer,
|
||||
size_t weight_offset,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
const struct gptoss_metal_buffer* argmax_buffer,
|
||||
size_t argmax_offset,
|
||||
uint32_t num_tokens,
|
||||
uint32_t num_cols,
|
||||
uint32_t num_rows);
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_mf4w_moe_matmul_swiglu_fn,
|
||||
size_t threadgroup_size,
|
||||
const struct gptoss_metal_buffer* input_buffer,
|
||||
size_t input_offset,
|
||||
const struct gptoss_metal_buffer* expert_buffer,
|
||||
size_t expert_offset,
|
||||
const struct gptoss_metal_buffer* weight_block_buffer,
|
||||
size_t weight_block_offset,
|
||||
const struct gptoss_metal_buffer* weight_scale_buffer,
|
||||
size_t weight_scale_offset,
|
||||
const struct gptoss_metal_buffer* bias_buffer,
|
||||
size_t bias_offset,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
float swiglu_limit,
|
||||
uint32_t expert_stride,
|
||||
uint32_t num_tokens,
|
||||
uint32_t num_active_experts,
|
||||
uint32_t num_cols,
|
||||
uint32_t num_rows);
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_mf4w_moe_matmul_fn,
|
||||
size_t threadgroup_size,
|
||||
const struct gptoss_metal_buffer* input_buffer,
|
||||
size_t input_offset,
|
||||
const struct gptoss_metal_buffer* expert_buffer,
|
||||
size_t expert_offset,
|
||||
const struct gptoss_metal_buffer* weight_block_buffer,
|
||||
size_t weight_block_offset,
|
||||
const struct gptoss_metal_buffer* weight_scale_buffer,
|
||||
size_t weight_scale_offset,
|
||||
const struct gptoss_metal_buffer* bias_buffer,
|
||||
size_t bias_offset,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
uint32_t expert_stride,
|
||||
uint32_t num_tokens,
|
||||
uint32_t num_active_experts,
|
||||
uint32_t num_cols,
|
||||
uint32_t num_rows);
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_rope_fn,
|
||||
size_t threadgroup_size,
|
||||
const struct gptoss_metal_buffer* activations_buffer,
|
||||
float rope_base,
|
||||
float interpolation_scale,
|
||||
float yarn_offset,
|
||||
float yarn_scale,
|
||||
float yarn_multiplier,
|
||||
uint32_t num_tokens,
|
||||
uint32_t num_q_heads,
|
||||
uint32_t num_kv_heads,
|
||||
uint32_t attn_head_dim,
|
||||
uint32_t token_offset);
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_accumulate_fn,
|
||||
size_t threadgroup_size,
|
||||
size_t max_threadgroups,
|
||||
const struct gptoss_metal_buffer* input_buffer,
|
||||
size_t input_offset,
|
||||
const struct gptoss_metal_buffer* expert_buffer,
|
||||
size_t expert_offset,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
uint32_t num_channels,
|
||||
uint32_t num_tokens,
|
||||
uint32_t num_experts);
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_topk_fn,
|
||||
const struct gptoss_metal_buffer* input_buffer,
|
||||
size_t input_offset,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
uint32_t num_tokens,
|
||||
uint32_t num_experts,
|
||||
uint32_t num_active_experts);
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_sdpa_fn,
|
||||
const struct gptoss_metal_buffer* q_buffer,
|
||||
size_t q_offset,
|
||||
const struct gptoss_metal_buffer* k_buffer,
|
||||
size_t k_offset,
|
||||
const struct gptoss_metal_buffer* v_buffer,
|
||||
size_t v_offset,
|
||||
const struct gptoss_metal_buffer* s_buffer,
|
||||
size_t s_offset,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
uint32_t window,
|
||||
uint32_t num_q_tokens,
|
||||
uint32_t num_kv_tokens,
|
||||
uint32_t num_q_heads,
|
||||
uint32_t num_kv_heads,
|
||||
uint32_t head_dim);
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_softmax_fn,
|
||||
size_t threadgroup_size,
|
||||
size_t max_threadgroups,
|
||||
const struct gptoss_metal_buffer* score_buffer,
|
||||
size_t score_offset,
|
||||
const struct gptoss_metal_buffer* argmax_buffer,
|
||||
size_t argmax_offset,
|
||||
const struct gptoss_metal_buffer* prob_buffer,
|
||||
size_t prob_offset,
|
||||
const struct gptoss_metal_buffer* sum_buffer,
|
||||
size_t sum_offset,
|
||||
uint32_t num_channels,
|
||||
uint32_t num_tokens,
|
||||
float temperature,
|
||||
uint32_t* num_threadgroups_out,
|
||||
uint32_t* num_channels_per_threadgroup_out);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
137
gpt_oss/metal/source/include/internal/metal.h
Normal file
137
gpt_oss/metal/source/include/internal/metal.h
Normal file
@@ -0,0 +1,137 @@
|
||||
#pragma once
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <gpt-oss/types.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct gptoss_metal_device {
|
||||
void* object; // id<MTLDevice>
|
||||
size_t num_cores;
|
||||
size_t max_buffer_size;
|
||||
size_t max_threadgroup_memory;
|
||||
size_t max_threadgroup_threads_x;
|
||||
size_t max_threadgroup_threads_y;
|
||||
size_t max_threadgroup_threads_z;
|
||||
};
|
||||
|
||||
enum gptoss_status gptoss_metal_device_create_system_default(
|
||||
struct gptoss_metal_device* device_out);
|
||||
|
||||
enum gptoss_status gptoss_metal_device_release(
|
||||
struct gptoss_metal_device* device);
|
||||
|
||||
|
||||
struct gptoss_metal_library {
|
||||
void* object; // id<MTLLibrary>
|
||||
};
|
||||
|
||||
enum gptoss_status gptoss_metal_library_create_default(
|
||||
const struct gptoss_metal_device* device,
|
||||
struct gptoss_metal_library* library_out);
|
||||
|
||||
enum gptoss_status gptoss_metal_library_release(
|
||||
struct gptoss_metal_library* library);
|
||||
|
||||
struct gptoss_metal_function {
|
||||
void* function_object; // id<MTLFunction>
|
||||
void* pipeline_state_object; // id<MTLComputePipelineState>
|
||||
size_t max_threadgroup_threads;
|
||||
size_t simdgroup_threads;
|
||||
size_t static_threadgroup_memory;
|
||||
};
|
||||
|
||||
enum gptoss_status gptoss_metal_function_create(
|
||||
const struct gptoss_metal_library* library,
|
||||
const char* name,
|
||||
struct gptoss_metal_function* function_out);
|
||||
|
||||
enum gptoss_status gptoss_metal_function_release(
|
||||
struct gptoss_metal_function* function);
|
||||
|
||||
struct gptoss_metal_buffer {
|
||||
void* object; // id<MTLBuffer>
|
||||
size_t size;
|
||||
void* ptr;
|
||||
};
|
||||
|
||||
enum gptoss_status gptoss_metal_buffer_create(
|
||||
const struct gptoss_metal_device* device,
|
||||
size_t size,
|
||||
const void* data,
|
||||
struct gptoss_metal_buffer* buffer_out);
|
||||
|
||||
enum gptoss_status gptoss_metal_buffer_wrap(
|
||||
const struct gptoss_metal_device* device,
|
||||
size_t size,
|
||||
const void* data,
|
||||
struct gptoss_metal_buffer* buffer_out);
|
||||
|
||||
enum gptoss_status gptoss_metal_buffer_release(
|
||||
struct gptoss_metal_buffer* buffer);
|
||||
|
||||
struct gptoss_metal_command_queue {
|
||||
void* object; // id<MTLCommandQueue>
|
||||
};
|
||||
|
||||
enum gptoss_status gptoss_metal_command_queue_create(
|
||||
const struct gptoss_metal_device* device,
|
||||
struct gptoss_metal_command_queue* command_queue_out);
|
||||
|
||||
enum gptoss_status gptoss_metal_command_queue_release(
|
||||
struct gptoss_metal_command_queue* command_queue);
|
||||
|
||||
struct gptoss_metal_command_buffer {
|
||||
void* object; // id<MTLCommandBuffer>
|
||||
};
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_create(
|
||||
const struct gptoss_metal_command_queue* command_queue,
|
||||
struct gptoss_metal_command_buffer* command_buffer_out);
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_fill_buffer(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_buffer* buffer,
|
||||
size_t offset,
|
||||
size_t size,
|
||||
uint8_t fill_value);
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_copy_buffer(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_buffer* input_buffer,
|
||||
size_t input_offset,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
size_t size);
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_kernel(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* function,
|
||||
size_t threadgroup_size_x,
|
||||
size_t threadgroup_size_y,
|
||||
size_t threadgroup_size_z,
|
||||
size_t num_threadgroups_x,
|
||||
size_t num_threadgroups_y,
|
||||
size_t num_threadgroups_z,
|
||||
size_t params_size,
|
||||
const void* params,
|
||||
size_t num_buffers,
|
||||
const struct gptoss_metal_buffer** buffers,
|
||||
const size_t* buffer_offsets);
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_commit(
|
||||
const struct gptoss_metal_command_buffer* command_buffer);
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_wait_completion(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
double* elapsed_seconds);
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_release(
|
||||
struct gptoss_metal_command_buffer* command_buffer);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
340
gpt_oss/metal/source/include/internal/metal.hpp
Normal file
340
gpt_oss/metal/source/include/internal/metal.hpp
Normal file
@@ -0,0 +1,340 @@
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <initializer_list>
|
||||
#include <cstring>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
#include <gpt-oss/types.h>
|
||||
#include <internal/metal.h>
|
||||
#include <internal/metal-kernels.h>
|
||||
|
||||
|
||||
namespace gptoss {
|
||||
|
||||
inline void Check(gptoss_status s, const char* what) {
|
||||
if (s != gptoss_status_success) {
|
||||
throw std::runtime_error(what);
|
||||
}
|
||||
}
|
||||
|
||||
inline std::size_t round_up(std::size_t p, std::size_t q) {
|
||||
const std::size_t r = p % q;
|
||||
if (r == 0) {
|
||||
return p;
|
||||
} else {
|
||||
return p - r + q;
|
||||
}
|
||||
}
|
||||
|
||||
namespace metal {
|
||||
|
||||
class Device {
|
||||
public:
|
||||
inline Device() {
|
||||
Check(gptoss_metal_device_create_system_default(&device_), "create Device");
|
||||
}
|
||||
|
||||
inline ~Device() {
|
||||
gptoss_metal_device_release(&device_);
|
||||
}
|
||||
|
||||
Device(const Device&) = delete;
|
||||
Device& operator=(const Device&) = delete;
|
||||
|
||||
inline Device(Device&& other) noexcept {
|
||||
device_ = other.device_;
|
||||
std::memset(&other.device_, 0, sizeof(other.device_));
|
||||
}
|
||||
|
||||
inline Device& operator=(Device&& other) noexcept {
|
||||
if (this != &other) {
|
||||
gptoss_metal_device_release(&device_);
|
||||
device_ = other.device_;
|
||||
std::memset(&other.device_, 0, sizeof(other.device_));
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline const gptoss_metal_device* handle() const noexcept { return &device_; }
|
||||
|
||||
inline size_t max_buffer_size() const noexcept { return device_.max_buffer_size; }
|
||||
inline size_t max_threadgroup_memory() const noexcept { return device_.max_threadgroup_memory; }
|
||||
inline size_t max_threadgroup_threads_x() const noexcept { return device_.max_threadgroup_threads_x; }
|
||||
inline size_t max_threadgroup_threads_y() const noexcept { return device_.max_threadgroup_threads_y; }
|
||||
inline size_t max_threadgroup_threads_z() const noexcept { return device_.max_threadgroup_threads_z; }
|
||||
|
||||
private:
|
||||
gptoss_metal_device device_{};
|
||||
};
|
||||
|
||||
class Library {
|
||||
public:
|
||||
inline explicit Library(const Device& dev) {
|
||||
Check(gptoss_metal_library_create_default(dev.handle(), &library_),
|
||||
"gptoss_metal_library_create_default");
|
||||
}
|
||||
|
||||
inline ~Library() {
|
||||
gptoss_metal_library_release(&library_);
|
||||
}
|
||||
|
||||
Library(const Library&) = delete;
|
||||
Library& operator=(const Library&) = delete;
|
||||
|
||||
inline Library(Library&& other) noexcept {
|
||||
library_ = other.library_;
|
||||
std::memset(&other.library_, 0, sizeof(other.library_));
|
||||
}
|
||||
|
||||
inline Library& operator=(Library&& other) noexcept {
|
||||
if (this != &other) {
|
||||
gptoss_metal_library_release(&library_);
|
||||
library_ = other.library_;
|
||||
std::memset(&other.library_, 0, sizeof(other.library_));
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline const gptoss_metal_library* handle() const noexcept {
|
||||
return &library_;
|
||||
}
|
||||
|
||||
private:
|
||||
gptoss_metal_library library_{};
|
||||
};
|
||||
|
||||
class Function {
|
||||
public:
|
||||
inline Function(const Library& library, const char* name) {
|
||||
Check(gptoss_metal_function_create(library.handle(), name, &function_),
|
||||
"gptoss_metal_function_create");
|
||||
}
|
||||
|
||||
inline ~Function() {
|
||||
gptoss_metal_function_release(&function_);
|
||||
}
|
||||
|
||||
Function(const Function&) = delete;
|
||||
Function& operator=(const Function&) = delete;
|
||||
|
||||
inline Function(Function&& other) noexcept {
|
||||
function_ = other.function_;
|
||||
std::memset(&other.function_, 0, sizeof(other.function_));
|
||||
}
|
||||
|
||||
inline Function& operator=(Function&& other) noexcept {
|
||||
if (this != &other) {
|
||||
gptoss_metal_function_release(&function_);
|
||||
function_ = other.function_;
|
||||
std::memset(&other.function_, 0, sizeof(other.function_));
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline const gptoss_metal_function* handle() const noexcept { return &function_; }
|
||||
|
||||
inline size_t max_threadgroup_threads() const noexcept { return function_.max_threadgroup_threads; }
|
||||
inline size_t simdgroup_threads() const noexcept { return function_.simdgroup_threads; }
|
||||
inline size_t static_threadgroup_memory() const noexcept { return function_.static_threadgroup_memory; }
|
||||
|
||||
private:
|
||||
gptoss_metal_function function_{};
|
||||
};
|
||||
|
||||
class Buffer {
|
||||
public:
|
||||
inline Buffer(const Device& dev, size_t size, const void* data = nullptr) {
|
||||
Check(gptoss_metal_buffer_create(dev.handle(), size, data, &buffer_), "create buffer");
|
||||
}
|
||||
|
||||
inline ~Buffer() {
|
||||
gptoss_metal_buffer_release(&buffer_);
|
||||
}
|
||||
|
||||
Buffer(const Buffer&) = delete;
|
||||
Buffer& operator=(const Buffer&) = delete;
|
||||
|
||||
inline Buffer(Buffer&& other) noexcept {
|
||||
buffer_ = other.buffer_;
|
||||
std::memset(&other.buffer_, 0, sizeof(other.buffer_));
|
||||
}
|
||||
|
||||
inline Buffer& operator=(Buffer&& other) noexcept {
|
||||
if (this != &other) {
|
||||
gptoss_metal_buffer_release(&buffer_);
|
||||
buffer_ = other.buffer_;
|
||||
std::memset(&other.buffer_, 0, sizeof(other.buffer_));
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline size_t size() const noexcept { return buffer_.size; }
|
||||
inline void* ptr() const noexcept { return buffer_.ptr; }
|
||||
|
||||
inline const gptoss_metal_buffer* handle() const noexcept { return &buffer_; }
|
||||
|
||||
private:
|
||||
gptoss_metal_buffer buffer_{};
|
||||
};
|
||||
|
||||
class CommandQueue {
|
||||
public:
|
||||
inline explicit CommandQueue(const Device& dev) {
|
||||
Check(gptoss_metal_command_queue_create(dev.handle(), &command_queue_),
|
||||
"gptoss_metal_command_queue_create");
|
||||
}
|
||||
|
||||
inline ~CommandQueue() {
|
||||
gptoss_metal_command_queue_release(&command_queue_);
|
||||
}
|
||||
|
||||
CommandQueue(const CommandQueue&) = delete;
|
||||
CommandQueue& operator=(const CommandQueue&) = delete;
|
||||
|
||||
inline CommandQueue(CommandQueue&& other) noexcept {
|
||||
command_queue_ = other.command_queue_;
|
||||
std::memset(&other.command_queue_, 0, sizeof(other.command_queue_));
|
||||
}
|
||||
|
||||
inline CommandQueue& operator=(CommandQueue&& other) noexcept {
|
||||
if (this != &other) {
|
||||
gptoss_metal_command_queue_release(&command_queue_);
|
||||
command_queue_ = other.command_queue_;
|
||||
std::memset(&other.command_queue_, 0, sizeof(other.command_queue_));
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline const gptoss_metal_command_queue* handle() const noexcept {
|
||||
return &command_queue_;
|
||||
}
|
||||
|
||||
private:
|
||||
gptoss_metal_command_queue command_queue_{};
|
||||
};
|
||||
|
||||
class CommandBuffer {
|
||||
public:
|
||||
inline explicit CommandBuffer(const CommandQueue& command_queue) {
|
||||
Check(gptoss_metal_command_buffer_create(command_queue.handle(), &command_buffer_),
|
||||
"gptoss_metal_command_buffer_create");
|
||||
}
|
||||
inline ~CommandBuffer() {
|
||||
gptoss_metal_command_buffer_release(&command_buffer_);
|
||||
}
|
||||
|
||||
CommandBuffer(const CommandBuffer&) = delete;
|
||||
CommandBuffer& operator=(const CommandBuffer&) = delete;
|
||||
|
||||
inline CommandBuffer(CommandBuffer&& other) noexcept {
|
||||
command_buffer_ = other.command_buffer_;
|
||||
std::memset(&other.command_buffer_, 0, sizeof(other.command_buffer_));
|
||||
}
|
||||
|
||||
inline CommandBuffer& operator=(CommandBuffer&& other) noexcept {
|
||||
if (this != &other) {
|
||||
gptoss_metal_command_buffer_release(&command_buffer_);
|
||||
command_buffer_ = other.command_buffer_;
|
||||
std::memset(&other.command_buffer_, 0, sizeof(other.command_buffer_));
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline void encode_launch_kernel(const Function& function,
|
||||
const std::array<size_t, 3>& threadgroup_size,
|
||||
const std::array<size_t, 3>& num_threadgroups,
|
||||
size_t params_size, const void* params,
|
||||
std::initializer_list<const Buffer*> buffers = {})
|
||||
{
|
||||
std::vector<const gptoss_metal_buffer*> buffer_handles(buffers.size());
|
||||
std::transform(buffers.begin(), buffers.end(), buffer_handles.begin(),
|
||||
[](const Buffer* buffer) -> const gptoss_metal_buffer* { return buffer->handle(); });
|
||||
Check(gptoss_metal_command_buffer_encode_launch_kernel(
|
||||
&command_buffer_, function.handle(),
|
||||
threadgroup_size[0], threadgroup_size[1], threadgroup_size[2],
|
||||
num_threadgroups[0], num_threadgroups[1], num_threadgroups[2],
|
||||
params_size, params,
|
||||
buffer_handles.size(),
|
||||
buffer_handles.data(),
|
||||
/*buffer_offsets=*/nullptr),
|
||||
"gptoss_metal_command_buffer_encode_launch_kernel");
|
||||
}
|
||||
|
||||
inline void encode_launch_f32_fill_random(const Function& f32_fill_random_fn,
|
||||
size_t threadgroup_size,
|
||||
size_t num_threadgroups,
|
||||
const Buffer& output_buffer,
|
||||
size_t output_offset,
|
||||
size_t num_channels,
|
||||
uint64_t rng_seed,
|
||||
uint64_t rng_offset,
|
||||
float rng_min,
|
||||
float rng_max)
|
||||
{
|
||||
Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(
|
||||
&command_buffer_, f32_fill_random_fn.handle(),
|
||||
threadgroup_size, num_threadgroups,
|
||||
output_buffer.handle(), output_offset,
|
||||
num_channels,
|
||||
rng_seed, rng_offset, rng_min, rng_max),
|
||||
"gptoss_metal_command_buffer_encode_launch_f32_fill_random");
|
||||
}
|
||||
|
||||
inline void encode_launch_bf16_fill_random(const Function& bf16_fill_random_fn,
|
||||
size_t threadgroup_size,
|
||||
size_t num_threadgroups,
|
||||
const Buffer& output_buffer,
|
||||
size_t output_offset,
|
||||
size_t num_channels,
|
||||
uint64_t rng_seed,
|
||||
uint64_t rng_offset,
|
||||
float rng_min,
|
||||
float rng_max)
|
||||
{
|
||||
Check(gptoss_metal_command_buffer_encode_launch_bf16_fill_random(
|
||||
&command_buffer_, bf16_fill_random_fn.handle(),
|
||||
threadgroup_size, num_threadgroups,
|
||||
output_buffer.handle(), output_offset,
|
||||
num_channels,
|
||||
rng_seed, rng_offset, rng_min, rng_max),
|
||||
"gptoss_metal_command_buffer_encode_launch_bf16_fill_random");
|
||||
}
|
||||
|
||||
inline void encode_launch_u32_fill_random(const Function& u32_fill_random_fn,
|
||||
size_t threadgroup_size,
|
||||
size_t num_threadgroups,
|
||||
const Buffer& output_buffer,
|
||||
size_t output_offset,
|
||||
size_t num_channels,
|
||||
uint64_t rng_seed,
|
||||
uint64_t rng_offset)
|
||||
{
|
||||
Check(gptoss_metal_command_buffer_encode_launch_u32_fill_random(
|
||||
&command_buffer_, u32_fill_random_fn.handle(),
|
||||
threadgroup_size, num_threadgroups,
|
||||
output_buffer.handle(), output_offset,
|
||||
num_channels,
|
||||
rng_seed, rng_offset),
|
||||
"gptoss_metal_command_buffer_encode_launch_u32_fill_random");
|
||||
}
|
||||
|
||||
inline void commit() {
|
||||
Check(gptoss_metal_command_buffer_commit(&command_buffer_), "commit");
|
||||
}
|
||||
|
||||
inline double wait_completion() {
|
||||
double secs = 0.0;
|
||||
Check(gptoss_metal_command_buffer_wait_completion(&command_buffer_, &secs), "wait completion");
|
||||
return secs;
|
||||
}
|
||||
|
||||
inline const gptoss_metal_command_buffer* handle() const noexcept { return &command_buffer_; }
|
||||
|
||||
private:
|
||||
gptoss_metal_command_buffer command_buffer_{};
|
||||
};
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gptoss
|
||||
153
gpt_oss/metal/source/include/internal/model.h
Normal file
153
gpt_oss/metal/source/include/internal/model.h
Normal file
@@ -0,0 +1,153 @@
|
||||
#pragma once
|
||||
|
||||
#include <stdatomic.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "internal/metal.h"
|
||||
|
||||
|
||||
struct gptoss_tokenizer {
|
||||
atomic_uint_least64_t ref_count;
|
||||
|
||||
void* mapping_ptr;
|
||||
size_t mapping_size;
|
||||
|
||||
const char* regex_ptr;
|
||||
const char* tokens_ptr;
|
||||
|
||||
uint32_t num_text_tokens;
|
||||
uint32_t num_special_tokens;
|
||||
|
||||
uint32_t special_token_id[gptoss_special_token_max - 1];
|
||||
};
|
||||
|
||||
struct gptoss_model {
|
||||
atomic_uint_least64_t ref_count;
|
||||
|
||||
struct gptoss_tokenizer* tokenizer;
|
||||
|
||||
void* mapping_ptr;
|
||||
size_t mapping_size;
|
||||
|
||||
uint32_t context_length;
|
||||
uint32_t num_blocks;
|
||||
uint32_t num_experts;
|
||||
uint32_t num_active_experts;
|
||||
uint32_t embedding_dim;
|
||||
uint32_t mlp_dim;
|
||||
float swiglu_limit;
|
||||
uint32_t head_dim;
|
||||
uint32_t num_heads;
|
||||
uint32_t num_kv_heads;
|
||||
uint32_t attention_window;
|
||||
float rope_theta;
|
||||
float interpolation_scale;
|
||||
float yarn_offset;
|
||||
float yarn_scale;
|
||||
float yarn_multiplier;
|
||||
float rmsnorm_epsilon;
|
||||
|
||||
uint32_t vocabulary_size;
|
||||
|
||||
// Maximum number of tokens that can be processed in a single batch.
|
||||
// Once the batch size is reached, we process it to fill the KV cache.
|
||||
size_t max_batch_tokens;
|
||||
|
||||
size_t weights_size;
|
||||
size_t allocation_size;
|
||||
|
||||
// Metal objects
|
||||
struct gptoss_metal_device device;
|
||||
size_t max_threadgroups;
|
||||
struct gptoss_metal_command_queue command_queue;
|
||||
struct gptoss_metal_library library;
|
||||
struct gptoss_metal_function bf16_f32_embeddings_fn;
|
||||
struct gptoss_metal_function f32_bf16w_rmsnorm_fn;
|
||||
struct gptoss_metal_function f32_bf16w_matmul_fn;
|
||||
struct gptoss_metal_function f32_bf16w_unembedding_fn;
|
||||
struct gptoss_metal_function f32_rope_fn;
|
||||
struct gptoss_metal_function f32_mf4w_moe_matmul_swiglu_fn;
|
||||
struct gptoss_metal_function f32_mf4w_moe_matmul_fn;
|
||||
struct gptoss_metal_function f32_accumulate_e4_fn;
|
||||
struct gptoss_metal_function f32_topk_softmax_e32_k4_fn;
|
||||
struct gptoss_metal_function f32_topk_softmax_e128_k4_fn;
|
||||
struct gptoss_metal_function f32_sdpa_q8_d64_fn;
|
||||
struct gptoss_metal_function f32_softmax_fn;
|
||||
|
||||
// Activation buffers.
|
||||
// TODO: merge into a single buffer.
|
||||
struct gptoss_metal_buffer residual_activation_buffer; // Residual stream
|
||||
struct gptoss_metal_buffer rmsnorm_activation_buffer; // Both attention & MLP RMSNorm output
|
||||
struct gptoss_metal_buffer qkv_activation_buffer; // QKV projection output
|
||||
struct gptoss_metal_buffer sdpa_activation_buffer; // SDPA output
|
||||
struct gptoss_metal_buffer gate_activation_buffer; // MoE gating output
|
||||
struct gptoss_metal_buffer expert_activation_buffer; // MoE expert predictions
|
||||
struct gptoss_metal_buffer swiglu_activation_buffer; // MLP+SwiGLU output
|
||||
struct gptoss_metal_buffer moe_activation_buffer; // MoE MLP output (per-active expert)
|
||||
|
||||
size_t per_block_shared_weights_size;
|
||||
size_t per_expert_block_weight_size;
|
||||
|
||||
size_t attn_rmsnorm_gain_offset;
|
||||
size_t attn_qkv_weight_offset;
|
||||
size_t attn_qkv_bias_offset;
|
||||
size_t attn_sdpa_sink_offset;
|
||||
size_t attn_out_weight_offset;
|
||||
size_t attn_out_bias_offset;
|
||||
size_t mlp_rmsnorm_gain_offset;
|
||||
size_t mlp_gate_weight_offset;
|
||||
size_t mlp_gate_bias_offset;
|
||||
size_t mlp_swiglu_scale_offset;
|
||||
size_t mlp_swiglu_bias_offset;
|
||||
size_t mlp_out_block_offset;
|
||||
size_t mlp_out_scale_offset;
|
||||
size_t mlp_out_bias_offset;
|
||||
size_t rmsnorm_weight_offset;
|
||||
size_t unembedding_weight_offset;
|
||||
|
||||
// Buffer with non-MoE weights. Includes MoE gates, embeddings/unembeddings.
|
||||
struct gptoss_metal_buffer shared_weight_buffer;
|
||||
// num_blocks per-block buffers with MoE weights to follow.
|
||||
struct gptoss_metal_buffer block_weight_buffers[];
|
||||
};
|
||||
|
||||
#define GPTOSS_DEFAULT_BATCH_SIZE 128
|
||||
|
||||
struct gptoss_context {
|
||||
atomic_uint_least64_t ref_count;
|
||||
|
||||
struct gptoss_model* model;
|
||||
// Number of tokens processed in the context.
|
||||
size_t num_tokens;
|
||||
// Number of tokens in the KV cache.
|
||||
size_t num_kv_tokens;
|
||||
// Length of the context.
|
||||
size_t max_tokens;
|
||||
|
||||
// Current number of tokens in the batch.
|
||||
// Always in the [0, max_batch_tokens) range.
|
||||
size_t num_batch_tokens;
|
||||
// Number of tokens processed in the last batch.
|
||||
// Activations for [num_batch_tokens, num_processed_tokens) tokens can be accessed from internal structures.
|
||||
size_t num_processed_tokens;
|
||||
|
||||
size_t kvcache_size;
|
||||
size_t allocation_size;
|
||||
|
||||
struct gptoss_metal_buffer token_buffer; // uint32 token IDs
|
||||
struct gptoss_metal_buffer score_buffer; // unembedding outputs
|
||||
struct gptoss_metal_buffer prob_buffer;
|
||||
struct gptoss_metal_buffer sum_buffer;
|
||||
struct gptoss_metal_buffer argmax_buffer;
|
||||
struct gptoss_metal_buffer kvcache_buffer;
|
||||
};
|
||||
|
||||
struct gptoss_sampler {
|
||||
atomic_uint_least64_t ref_count;
|
||||
|
||||
float temperature;
|
||||
float top_p;
|
||||
float presence_penalty;
|
||||
float frequency_penalty;
|
||||
};
|
||||
24
gpt_oss/metal/source/include/internal/rng.h
Normal file
24
gpt_oss/metal/source/include/internal/rng.h
Normal file
@@ -0,0 +1,24 @@
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
inline static uint32_t rng_squares32(uint64_t offset, uint64_t seed) {
|
||||
const uint64_t y = offset * seed;
|
||||
const uint64_t z = y + seed;
|
||||
|
||||
/* Round 1 */
|
||||
uint64_t x = y * y + y;
|
||||
x = (x >> 32) | (x << 32);
|
||||
|
||||
/* Round 2 */
|
||||
x = x * x + z;
|
||||
x = (x >> 32) | (x << 32);
|
||||
|
||||
/* Round 3 */
|
||||
x = x * x + y;
|
||||
x = (x >> 32) | (x << 32);
|
||||
|
||||
/* Round 4 */
|
||||
x = x * x + z;
|
||||
return (uint32_t) (x >> 32);
|
||||
}
|
||||
32
gpt_oss/metal/source/include/internal/rng.hpp
Normal file
32
gpt_oss/metal/source/include/internal/rng.hpp
Normal file
@@ -0,0 +1,32 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace gptoss {
|
||||
|
||||
namespace rng {
|
||||
|
||||
inline static std::uint32_t squares32(std::uint64_t offset, std::uint64_t seed) {
|
||||
const std::uint64_t y = offset * seed;
|
||||
const std::uint64_t z = y + seed;
|
||||
|
||||
/* Round 1 */
|
||||
std::uint64_t x = y * y + y;
|
||||
x = (x >> 32) | (x << 32);
|
||||
|
||||
/* Round 2 */
|
||||
x = x * x + z;
|
||||
x = (x >> 32) | (x << 32);
|
||||
|
||||
/* Round 3 */
|
||||
x = x * x + y;
|
||||
x = (x >> 32) | (x << 32);
|
||||
|
||||
/* Round 4 */
|
||||
x = x * x + z;
|
||||
return static_cast<uint32_t>(x >> 32);
|
||||
}
|
||||
|
||||
} // namespace rng
|
||||
|
||||
} // namespace gptoss
|
||||
36
gpt_oss/metal/source/include/internal/storage.h
Normal file
36
gpt_oss/metal/source/include/internal/storage.h
Normal file
@@ -0,0 +1,36 @@
|
||||
#pragma once
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
|
||||
struct gptoss_file_header {
|
||||
char magic[12];
|
||||
uint32_t zero;
|
||||
};
|
||||
|
||||
struct gptoss_gptoss_model_header {
|
||||
uint32_t context_length;
|
||||
uint32_t num_blocks;
|
||||
uint32_t num_experts;
|
||||
uint32_t num_active_experts;
|
||||
uint32_t embedding_dim;
|
||||
uint32_t mlp_dim;
|
||||
float swiglu_limit;
|
||||
uint32_t head_dim;
|
||||
uint32_t num_heads;
|
||||
uint32_t num_kv_heads;
|
||||
uint32_t attention_window;
|
||||
float rope_theta;
|
||||
float interpolation_scale;
|
||||
float yarn_offset;
|
||||
float yarn_scale;
|
||||
float yarn_multiplier;
|
||||
float rmsnorm_epsilon;
|
||||
};
|
||||
|
||||
struct gptoss_tiktoken_tokenizer_header {
|
||||
uint32_t num_special_tokens;
|
||||
uint32_t num_text_tokens;
|
||||
uint32_t regex_size;
|
||||
uint32_t tokens_size;
|
||||
};
|
||||
114
gpt_oss/metal/source/include/internal/uuid.h
Normal file
114
gpt_oss/metal/source/include/internal/uuid.h
Normal file
@@ -0,0 +1,114 @@
|
||||
#pragma once
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "internal/macros.h"
|
||||
|
||||
|
||||
struct GPTOSS_DENSELY_PACKED_STRUCTURE gptoss_uuid {
|
||||
uint8_t bytes[16];
|
||||
};
|
||||
static_assert(sizeof(struct gptoss_uuid) == 16, "UUID size is not 16 bytes");
|
||||
|
||||
|
||||
#define UUID_FORMAT "%02X%02X%02X%02X-%02X%02X-%02X%02X-%02X%02X-%02X%02X%02X%02X%02X%02X"
|
||||
#define UUID_ARGS(uuid) (uuid).bytes[0], (uuid).bytes[1], (uuid).bytes[2], (uuid).bytes[3], \
|
||||
(uuid).bytes[4], (uuid).bytes[5], (uuid).bytes[6], (uuid).bytes[7], (uuid).bytes[8], (uuid).bytes[9], \
|
||||
(uuid).bytes[10], (uuid).bytes[11], (uuid).bytes[12], (uuid).bytes[13], (uuid).bytes[14], (uuid).bytes[15]
|
||||
|
||||
static inline bool gptoss_is_gptoss_model_uuid(const struct gptoss_uuid* uuid) {
|
||||
return memcmp(
|
||||
&(struct gptoss_uuid) {0xDF, 0x52, 0xDC, 0x86, 0x17, 0x89, 0x4E, 0xD0, 0xA2, 0x95, 0x66, 0xF1, 0x05, 0x08, 0x14, 0x5B},
|
||||
uuid,
|
||||
sizeof(struct gptoss_uuid)) == 0;
|
||||
}
|
||||
|
||||
static inline bool gptoss_is_applegpu_layout_uuid(const struct gptoss_uuid* uuid) {
|
||||
return memcmp(
|
||||
&(struct gptoss_uuid) {0x22, 0x91, 0x77, 0xA8, 0x57, 0x75, 0x42, 0x68, 0xBF, 0xD8, 0xD5, 0x88, 0xB3, 0x51, 0xC5, 0x6D},
|
||||
uuid,
|
||||
sizeof(struct gptoss_uuid)) == 0;
|
||||
}
|
||||
|
||||
static inline bool gptoss_is_tiktoken_tokenizer_uuid(const struct gptoss_uuid* uuid) {
|
||||
return memcmp(
|
||||
&(struct gptoss_uuid) {0x74, 0x01, 0xAD, 0xED, 0x2A, 0x95, 0x40, 0xCB, 0xB7, 0x82, 0x9C, 0xCE, 0xBA, 0xAF, 0xE7, 0x2B},
|
||||
uuid,
|
||||
sizeof(struct gptoss_uuid)) == 0;
|
||||
}
|
||||
|
||||
static inline enum gptoss_special_token gptoss_special_token_decode_uuid(const struct gptoss_uuid* uuid) {
|
||||
if (memcmp(
|
||||
&(struct gptoss_uuid) {0x55, 0xA7, 0x7C, 0x2F, 0x8A, 0x01, 0x4C, 0x54, 0x8A, 0xC2, 0x31, 0x3B, 0xFC, 0x7E, 0x20, 0x8D},
|
||||
uuid,
|
||||
sizeof(struct gptoss_uuid)) == 0)
|
||||
{
|
||||
return gptoss_special_token_start;
|
||||
} else if (memcmp(
|
||||
&(struct gptoss_uuid) {0x16, 0xE4, 0x04, 0x31, 0xF4, 0x7F, 0x4B, 0x22, 0xB5, 0x9B, 0x8B, 0x27, 0x8F, 0xC3, 0x0A, 0x54},
|
||||
uuid,
|
||||
sizeof(struct gptoss_uuid)) == 0)
|
||||
{
|
||||
return gptoss_special_token_message;
|
||||
} else if (memcmp(
|
||||
&(struct gptoss_uuid) {0xFC, 0xAC, 0x2F, 0x6D, 0x47, 0x05, 0x4F, 0x6B, 0xB2, 0x28, 0x64, 0x2A, 0xCC, 0xAC, 0x72, 0x38},
|
||||
uuid,
|
||||
sizeof(struct gptoss_uuid)) == 0)
|
||||
{
|
||||
return gptoss_special_token_end;
|
||||
} else if (memcmp(
|
||||
&(struct gptoss_uuid) {0xF7, 0x99, 0xFF, 0x69, 0x19, 0x92, 0x43, 0xC4, 0xA3, 0xD8, 0xD8, 0x31, 0xF4, 0x75, 0xDC, 0x75},
|
||||
uuid,
|
||||
sizeof(struct gptoss_uuid)) == 0)
|
||||
{
|
||||
return gptoss_special_token_return;
|
||||
} else if (memcmp(
|
||||
&(struct gptoss_uuid) {0xE1, 0x5B, 0xA7, 0x02, 0x28, 0xC4, 0x42, 0x92, 0xAB, 0x8F, 0xFF, 0xA4, 0x34, 0x70, 0x91, 0x28},
|
||||
uuid,
|
||||
sizeof(struct gptoss_uuid)) == 0)
|
||||
{
|
||||
return gptoss_special_token_refusal;
|
||||
} else if (memcmp(
|
||||
&(struct gptoss_uuid) {0xC0, 0xBB, 0x14, 0xC7, 0x60, 0x22, 0x49, 0xDA, 0xAD, 0x08, 0x79, 0x2D, 0x67, 0xE8, 0xB4, 0x70},
|
||||
uuid,
|
||||
sizeof(struct gptoss_uuid)) == 0)
|
||||
{
|
||||
return gptoss_special_token_constrain;
|
||||
} else if (memcmp(
|
||||
&(struct gptoss_uuid) {0xFD, 0x3D, 0xDA, 0x11, 0xC8, 0xAB, 0x40, 0x33, 0x87, 0x6E, 0xD9, 0x3D, 0xEB, 0x17, 0x2C, 0x93},
|
||||
uuid,
|
||||
sizeof(struct gptoss_uuid)) == 0)
|
||||
{
|
||||
return gptoss_special_token_channel;
|
||||
} else if (memcmp(
|
||||
&(struct gptoss_uuid) {0x12, 0x20, 0xF7, 0x96, 0xE3, 0x88, 0x4D, 0xE5, 0xB4, 0x87, 0xFE, 0x2E, 0xB5, 0xFE, 0x03, 0xC0},
|
||||
uuid,
|
||||
sizeof(struct gptoss_uuid)) == 0)
|
||||
{
|
||||
return gptoss_special_token_call;
|
||||
} else if (memcmp(
|
||||
&(struct gptoss_uuid) {0x07, 0xD7, 0xDA, 0x55, 0xB3, 0x46, 0x4C, 0xFF, 0x8B, 0x37, 0x7C, 0xEF, 0xAC, 0xF8, 0xA3, 0xE8},
|
||||
uuid,
|
||||
sizeof(struct gptoss_uuid)) == 0)
|
||||
{
|
||||
return gptoss_special_token_untrusted;
|
||||
} else if (memcmp(
|
||||
&(struct gptoss_uuid) {0xF2, 0x65, 0xBD, 0x9C, 0xC7, 0x17, 0x46, 0x9E, 0xA4, 0x47, 0x92, 0x06, 0x87, 0xD6, 0x5D, 0x90},
|
||||
uuid,
|
||||
sizeof(struct gptoss_uuid)) == 0)
|
||||
{
|
||||
return gptoss_special_token_end_untrusted;
|
||||
} else if (memcmp(
|
||||
&(struct gptoss_uuid) {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
|
||||
uuid,
|
||||
sizeof(struct gptoss_uuid)) == 0)
|
||||
{
|
||||
// Suppress warning
|
||||
return gptoss_special_token_invalid;
|
||||
} else {
|
||||
GPTOSS_LOG_WARNING("unsupported special token " UUID_FORMAT, UUID_ARGS(*uuid));
|
||||
return gptoss_special_token_invalid;
|
||||
}
|
||||
}
|
||||
50
gpt_oss/metal/source/log.c
Normal file
50
gpt_oss/metal/source/log.c
Normal file
@@ -0,0 +1,50 @@
|
||||
#include <assert.h> // assert
|
||||
#include <stdarg.h> // va_list, va_copy, va_end
|
||||
#include <stdio.h> // vsnprintf
|
||||
#include <stdlib.h> // malloc, free
|
||||
|
||||
#include <unistd.h> // STDERR_FILENO
|
||||
|
||||
|
||||
|
||||
#define GPTOSS_ON_STACK_FORMAT_BUFFER_SIZE 16384
|
||||
|
||||
void gptoss_format_log(const char* format, va_list args) {
|
||||
char stack_buffer[GPTOSS_ON_STACK_FORMAT_BUFFER_SIZE];
|
||||
char* heap_buffer = NULL;
|
||||
|
||||
va_list args_copy;
|
||||
va_copy(args_copy, args);
|
||||
|
||||
const int vsnprintf_result = vsnprintf(stack_buffer, GPTOSS_ON_STACK_FORMAT_BUFFER_SIZE, format, args);
|
||||
assert(vsnprintf_result >= 0);
|
||||
|
||||
// At least a partially formatted buffer is ready.
|
||||
char* message_buffer = &stack_buffer[0];
|
||||
size_t message_size = (size_t) vsnprintf_result;
|
||||
if (message_size > GPTOSS_ON_STACK_FORMAT_BUFFER_SIZE) {
|
||||
heap_buffer = malloc(message_size);
|
||||
if (heap_buffer == NULL) {
|
||||
// Fall back to the truncated message in the on-stack buffer.
|
||||
message_size = GPTOSS_ON_STACK_FORMAT_BUFFER_SIZE;
|
||||
} else {
|
||||
// Use the full message in the in-heap buffer.
|
||||
vsnprintf(heap_buffer, message_size, format, args_copy);
|
||||
message_buffer = heap_buffer;
|
||||
}
|
||||
}
|
||||
|
||||
ssize_t bytes_written;
|
||||
do {
|
||||
bytes_written = write(STDERR_FILENO, message_buffer, message_size);
|
||||
if (bytes_written > 0) {
|
||||
assert((size_t) bytes_written <= message_size);
|
||||
message_buffer += bytes_written;
|
||||
message_size -= bytes_written;
|
||||
}
|
||||
} while (bytes_written >= 0 && message_size != 0);
|
||||
|
||||
cleanup:
|
||||
free(heap_buffer);
|
||||
va_end(args_copy);
|
||||
}
|
||||
137
gpt_oss/metal/source/matmul.metal
Normal file
137
gpt_oss/metal/source/matmul.metal
Normal file
@@ -0,0 +1,137 @@
|
||||
#include <metal_atomic>
|
||||
#include <metal_compute>
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include <internal/kernel-args.h>
|
||||
|
||||
#pragma METAL fp math_mode(safe)
|
||||
#pragma METAL fp contract(off)
|
||||
|
||||
|
||||
// Each simdgroup reduces all channels of the input and computes a single channel of the output
|
||||
// + Efficient synchronization
|
||||
// + Sequential memory access within a warp
|
||||
// Each threadgroup computes (simdgroups_per_threadgroup) consecutive output channels
|
||||
// + Reuse input vector from threadgroup memory
|
||||
// + Avoid synchronization across warps when doing reduction
|
||||
|
||||
kernel void gptoss_f32_bf16w_matmul(
|
||||
constant gptoss_matmul_args& args [[ buffer(0) ]],
|
||||
const device float4* input [[ buffer(1) ]],
|
||||
const device bfloat4* weight [[ buffer(2) ]],
|
||||
const device bfloat* bias [[ buffer(3) ]],
|
||||
device float* output [[ buffer(4) ]],
|
||||
uint2 gid [[threadgroup_position_in_grid]],
|
||||
uint simdgroup_tid [[thread_index_in_simdgroup]],
|
||||
uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
|
||||
uint num_simdgroups [[simdgroups_per_threadgroup]])
|
||||
{
|
||||
const uint simdgroup_size = 32;
|
||||
|
||||
const uint num_column_vecs = args.num_column_vecs;
|
||||
const uint row = gid.x * num_simdgroups + simdgroup_idx;
|
||||
|
||||
input += gid.y * num_column_vecs + simdgroup_tid;
|
||||
weight += num_column_vecs * row + simdgroup_tid;
|
||||
bias += row;
|
||||
output += gid.y * args.num_rows + row;
|
||||
|
||||
uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;
|
||||
|
||||
float4 sum4 = 0.0f;
|
||||
do {
|
||||
const bfloat4 w = *weight;
|
||||
const float4 i = *input;
|
||||
sum4 = metal::fma(static_cast<float4>(w), i, sum4);
|
||||
|
||||
weight += simdgroup_size;
|
||||
input += simdgroup_size;
|
||||
} while (--num_iter != 0);
|
||||
const float2 sum2 = sum4.xy + sum4.zw;
|
||||
float sum = sum2.x + sum2.y;
|
||||
sum = metal::simd_sum(sum);
|
||||
if (metal::simd_is_first()) {
|
||||
sum += static_cast<float>(*bias);
|
||||
if (args.add) {
|
||||
*output += sum;
|
||||
} else {
|
||||
*output = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void gptoss_f32_bf16w_unembedding(
|
||||
constant gptoss_unembedding_args& args [[ buffer(0) ]],
|
||||
const device float4* input [[ buffer(1) ]],
|
||||
const device bfloat4* weight [[ buffer(2) ]],
|
||||
device float* output [[ buffer(3) ]],
|
||||
device metal::atomic_ulong* argmax [[ buffer(4) ]],
|
||||
uint2 gid [[threadgroup_position_in_grid]],
|
||||
uint simdgroup_tid [[thread_index_in_simdgroup]],
|
||||
uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
|
||||
uint num_simdgroups [[simdgroups_per_threadgroup]])
|
||||
{
|
||||
const uint simdgroup_size = 32;
|
||||
threadgroup uint2 threadgroup_buffer[32];
|
||||
|
||||
const uint num_column_vecs = args.num_column_vecs;
|
||||
const uint row_start = gid.x * args.num_rows_per_threadgroup + simdgroup_idx;
|
||||
const uint row_end = metal::min(gid.x * args.num_rows_per_threadgroup + args.num_rows_per_threadgroup, args.num_rows);
|
||||
const uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;
|
||||
|
||||
input += gid.y * num_column_vecs + simdgroup_tid;
|
||||
weight += num_column_vecs * row_start + simdgroup_tid;
|
||||
output += gid.y * args.num_rows + row_start;
|
||||
|
||||
uint2 row_sum{0xFFFFFFFFul, 0xFFFFFFFFul};
|
||||
for (uint row = row_start; row < row_end; row += num_simdgroups) {
|
||||
uint n = num_iter;
|
||||
|
||||
float4 sum4 = 0.0f;
|
||||
do {
|
||||
const bfloat4 w = *weight;
|
||||
const float4 i = *input;
|
||||
|
||||
sum4 = metal::fma(static_cast<float4>(w), i, sum4);
|
||||
|
||||
weight += simdgroup_size;
|
||||
input += simdgroup_size;
|
||||
} while (--n != 0);
|
||||
input -= num_iter * simdgroup_size;
|
||||
weight -= num_iter * simdgroup_size;
|
||||
|
||||
const float2 sum2 = sum4.xy + sum4.zw;
|
||||
float sum = sum2.x + sum2.y;
|
||||
sum = metal::simd_sum(sum);
|
||||
uint sum_bits = as_type<uint>(sum);
|
||||
if (static_cast<int>(sum_bits) >= 0) {
|
||||
sum_bits ^= 0x7FFFFFFFu;
|
||||
}
|
||||
row_sum = as_type<uint2>(metal::min(as_type<ulong>(row_sum), as_type<ulong>(uint2{row, sum_bits})));
|
||||
if (metal::simd_is_first()) {
|
||||
*output = sum;
|
||||
}
|
||||
|
||||
weight += num_column_vecs * num_simdgroups;
|
||||
output += num_simdgroups;
|
||||
}
|
||||
if (metal::simd_is_first()) {
|
||||
threadgroup_buffer[simdgroup_idx] = row_sum;
|
||||
}
|
||||
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
|
||||
if (simdgroup_idx == 0) {
|
||||
// Min-Reduce threadgroup_buffer
|
||||
if (simdgroup_tid < num_simdgroups) {
|
||||
row_sum = threadgroup_buffer[simdgroup_tid];
|
||||
}
|
||||
const uint sum_bits = row_sum.y;
|
||||
const uint sum_bits_min = metal::simd_min(sum_bits);
|
||||
const uint row_min = metal::simd_min(sum_bits == sum_bits_min ? row_sum.x : 0xFFFFFFFFu);
|
||||
if (metal::simd_is_first()) {
|
||||
const uint2 threadgroup_output{row_min, sum_bits_min};
|
||||
atomic_min_explicit(&argmax[gid.y], as_type<ulong>(threadgroup_output), metal::memory_order_relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
817
gpt_oss/metal/source/metal-kernels.c
Normal file
817
gpt_oss/metal/source/metal-kernels.c
Normal file
@@ -0,0 +1,817 @@
|
||||
#include <inttypes.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <math.h>
|
||||
|
||||
#include <internal/kernel-args.h>
|
||||
#include <internal/log.h>
|
||||
#include <internal/math.h>
|
||||
#include <internal/metal.h>
|
||||
#include <internal/metal-kernels.h>
|
||||
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_u32_fill_random(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* u32_fill_random_fn,
|
||||
size_t threadgroup_size,
|
||||
size_t max_threadgroups,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
uint64_t num_elements,
|
||||
uint64_t rng_seed,
|
||||
uint64_t rng_offset)
|
||||
{
|
||||
if (command_buffer->object == NULL || u32_fill_random_fn->pipeline_state_object == NULL) {
|
||||
return gptoss_status_invalid_state;
|
||||
}
|
||||
|
||||
if (threadgroup_size == 0) {
|
||||
threadgroup_size = u32_fill_random_fn->max_threadgroup_threads;
|
||||
} else if (threadgroup_size > u32_fill_random_fn->max_threadgroup_threads) {
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
const size_t num_vecs = num_elements;
|
||||
const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;
|
||||
const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));
|
||||
const struct gptoss_u32_fill_random_args args = {
|
||||
.num_vecs = num_vecs,
|
||||
.num_vecs_per_threadgroup = num_vecs_per_threadgroup,
|
||||
.seed = rng_seed,
|
||||
.offset = rng_offset,
|
||||
};
|
||||
|
||||
return gptoss_metal_command_buffer_encode_launch_kernel(
|
||||
command_buffer, u32_fill_random_fn,
|
||||
threadgroup_size, 1, 1,
|
||||
num_threadgroups, 1, 1,
|
||||
sizeof(args), &args,
|
||||
1, &output_buffer, &output_offset);
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_fill_random(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_fill_random_fn,
|
||||
size_t threadgroup_size,
|
||||
size_t max_threadgroups,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
uint64_t num_elements,
|
||||
uint64_t rng_seed,
|
||||
uint64_t rng_offset,
|
||||
float rng_min,
|
||||
float rng_max)
|
||||
{
|
||||
if (command_buffer->object == NULL || f32_fill_random_fn->pipeline_state_object == NULL) {
|
||||
return gptoss_status_invalid_state;
|
||||
}
|
||||
|
||||
if (threadgroup_size == 0) {
|
||||
threadgroup_size = f32_fill_random_fn->max_threadgroup_threads;
|
||||
} else if (threadgroup_size > f32_fill_random_fn->max_threadgroup_threads) {
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
if (rng_min >= rng_max) {
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
const size_t num_vecs = num_elements;
|
||||
const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;
|
||||
const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));
|
||||
const struct gptoss_f32_fill_random_args args = {
|
||||
.num_vecs = num_vecs,
|
||||
.num_vecs_per_threadgroup = num_vecs_per_threadgroup,
|
||||
.seed = rng_seed,
|
||||
.offset = rng_offset,
|
||||
.scale = (rng_max - rng_min) * 0x1.0p-32f,
|
||||
.bias = (rng_min + rng_max) * 0.5f,
|
||||
};
|
||||
|
||||
return gptoss_metal_command_buffer_encode_launch_kernel(
|
||||
command_buffer, f32_fill_random_fn,
|
||||
threadgroup_size, 1, 1,
|
||||
num_threadgroups, 1, 1,
|
||||
sizeof(args), &args,
|
||||
1, &output_buffer, &output_offset);
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_fill_random(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* bf16_fill_random_fn,
|
||||
size_t threadgroup_size,
|
||||
size_t max_threadgroups,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
uint64_t num_elements,
|
||||
uint64_t rng_seed,
|
||||
uint64_t rng_offset,
|
||||
float rng_min,
|
||||
float rng_max)
|
||||
{
|
||||
if (command_buffer->object == NULL || bf16_fill_random_fn->pipeline_state_object == NULL) {
|
||||
return gptoss_status_invalid_state;
|
||||
}
|
||||
|
||||
if (threadgroup_size == 0) {
|
||||
threadgroup_size = bf16_fill_random_fn->max_threadgroup_threads;
|
||||
} else if (threadgroup_size > bf16_fill_random_fn->max_threadgroup_threads) {
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
if (rng_min >= rng_max) {
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
const size_t num_vecs = num_elements;
|
||||
const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;
|
||||
const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));
|
||||
const struct gptoss_f32_fill_random_args args = {
|
||||
.num_vecs = num_vecs,
|
||||
.num_vecs_per_threadgroup = num_vecs_per_threadgroup,
|
||||
.seed = rng_seed,
|
||||
.offset = rng_offset,
|
||||
.scale = (rng_max - rng_min) * 0x1.0p-32f,
|
||||
.bias = (rng_min + rng_max) * 0.5f,
|
||||
};
|
||||
|
||||
return gptoss_metal_command_buffer_encode_launch_kernel(
|
||||
command_buffer, bf16_fill_random_fn,
|
||||
threadgroup_size, 1, 1,
|
||||
num_threadgroups, 1, 1,
|
||||
sizeof(args), &args,
|
||||
1, &output_buffer, &output_offset);
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_mf4_f32_convert(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* mf4_f32_convert_fn,
|
||||
size_t threadgroup_size,
|
||||
size_t max_threadgroups,
|
||||
const struct gptoss_metal_buffer* block_buffer,
|
||||
const struct gptoss_metal_buffer* scale_buffer,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
uint64_t num_elements)
|
||||
{
|
||||
if (command_buffer->object == NULL || mf4_f32_convert_fn->pipeline_state_object == NULL) {
|
||||
return gptoss_status_invalid_state;
|
||||
}
|
||||
|
||||
if (num_elements % 32 != 0) {
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
if (threadgroup_size == 0) {
|
||||
threadgroup_size = mf4_f32_convert_fn->max_threadgroup_threads;
|
||||
} else if (threadgroup_size > mf4_f32_convert_fn->max_threadgroup_threads) {
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
const size_t num_vecs = num_elements / 32;
|
||||
const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;
|
||||
const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));
|
||||
const struct gptoss_convert_args args = {
|
||||
.num_vecs = num_vecs,
|
||||
.num_vecs_per_threadgroup = num_vecs_per_threadgroup,
|
||||
};
|
||||
|
||||
return gptoss_metal_command_buffer_encode_launch_kernel(
|
||||
command_buffer, mf4_f32_convert_fn,
|
||||
threadgroup_size, 1, 1,
|
||||
num_threadgroups, 1, 1,
|
||||
sizeof(args), &args,
|
||||
3, (const struct gptoss_metal_buffer *[]) {block_buffer, scale_buffer, output_buffer}, NULL);
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* bf16_f32_embeddings_fn,
|
||||
size_t threadgroup_size,
|
||||
const struct gptoss_metal_buffer* token_buffer,
|
||||
size_t token_offset,
|
||||
const struct gptoss_metal_buffer* weight_buffer,
|
||||
size_t weight_offset,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
uint32_t num_tokens,
|
||||
uint32_t num_channels)
|
||||
{
|
||||
if (command_buffer->object == NULL || bf16_f32_embeddings_fn->pipeline_state_object == NULL) {
|
||||
return gptoss_status_invalid_state;
|
||||
}
|
||||
|
||||
if (num_channels % 4 != 0) {
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
if (threadgroup_size == 0) {
|
||||
threadgroup_size = bf16_f32_embeddings_fn->max_threadgroup_threads;
|
||||
} else if (threadgroup_size > bf16_f32_embeddings_fn->max_threadgroup_threads) {
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
const uint32_t num_vecs = num_channels / 4;
|
||||
const struct gptoss_embeddings_args args = {
|
||||
.num_vecs = num_vecs,
|
||||
};
|
||||
|
||||
return gptoss_metal_command_buffer_encode_launch_kernel(
|
||||
command_buffer, bf16_f32_embeddings_fn,
|
||||
threadgroup_size, 1, 1,
|
||||
num_tokens, 1, 1,
|
||||
sizeof(args), &args,
|
||||
3,
|
||||
(const struct gptoss_metal_buffer *[]) {token_buffer, weight_buffer, output_buffer},
|
||||
(const size_t[]) {token_offset, weight_offset, output_offset});
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_bf16w_rmsnorm_fn,
|
||||
const struct gptoss_metal_buffer* input_buffer,
|
||||
size_t input_offset,
|
||||
const struct gptoss_metal_buffer* weight_buffer,
|
||||
size_t weight_offset,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
uint32_t num_tokens,
|
||||
uint32_t num_channels,
|
||||
float epsilon)
|
||||
{
|
||||
if (command_buffer->object == NULL || f32_bf16w_rmsnorm_fn->pipeline_state_object == NULL) {
|
||||
return gptoss_status_invalid_state;
|
||||
}
|
||||
|
||||
if (num_channels % 4 != 0) {
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
if (f32_bf16w_rmsnorm_fn->max_threadgroup_threads < 1024) {
|
||||
return gptoss_status_unsupported_system;
|
||||
}
|
||||
|
||||
if (f32_bf16w_rmsnorm_fn->simdgroup_threads != 32) {
|
||||
return gptoss_status_unsupported_system;
|
||||
}
|
||||
|
||||
const uint32_t num_vecs = num_channels / 4;
|
||||
const struct gptoss_rmsnorm_args args = {
|
||||
.num_vecs = num_vecs,
|
||||
.num_channels = (float) num_channels,
|
||||
.epsilon = epsilon,
|
||||
};
|
||||
|
||||
return gptoss_metal_command_buffer_encode_launch_kernel(
|
||||
command_buffer, f32_bf16w_rmsnorm_fn,
|
||||
/*threadgroup_size=*/1024, 1, 1,
|
||||
num_tokens, 1, 1,
|
||||
sizeof(args), &args,
|
||||
3,
|
||||
(const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer},
|
||||
(const size_t[]) {input_offset, weight_offset, output_offset});
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_bf16w_matmul_fn,
|
||||
size_t threadgroup_size,
|
||||
const struct gptoss_metal_buffer* input_buffer,
|
||||
size_t input_offset,
|
||||
const struct gptoss_metal_buffer* weight_buffer,
|
||||
size_t weight_offset,
|
||||
const struct gptoss_metal_buffer* bias_buffer,
|
||||
size_t bias_offset,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
uint32_t num_tokens,
|
||||
uint32_t num_cols,
|
||||
uint32_t num_rows)
|
||||
{
|
||||
if (command_buffer->object == NULL || f32_bf16w_matmul_fn->pipeline_state_object == NULL) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch: invalid command buffer or pipeline state object");
|
||||
return gptoss_status_invalid_state;
|
||||
}
|
||||
|
||||
if (threadgroup_size == 0) {
|
||||
threadgroup_size = f32_bf16w_matmul_fn->simdgroup_threads;
|
||||
} else if (threadgroup_size > f32_bf16w_matmul_fn->max_threadgroup_threads) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)",
|
||||
threadgroup_size, f32_bf16w_matmul_fn->max_threadgroup_threads);
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
if (num_cols % 4 != 0) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch: number of columns (%" PRIu32 ") is not divisible by 4",
|
||||
num_cols);
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
const size_t num_simdgroups = threadgroup_size / f32_bf16w_matmul_fn->simdgroup_threads;
|
||||
if (num_rows % num_simdgroups != 0) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch: number of rows (%" PRIu32 ") is not divisible by the number of simdgroups (%zu)",
|
||||
num_rows, num_simdgroups);
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
const struct gptoss_matmul_args args = {
|
||||
.num_column_vecs = num_cols / 4,
|
||||
.num_rows = num_rows,
|
||||
.add = 0,
|
||||
};
|
||||
|
||||
return gptoss_metal_command_buffer_encode_launch_kernel(
|
||||
command_buffer, f32_bf16w_matmul_fn,
|
||||
threadgroup_size, 1, 1,
|
||||
num_rows / num_simdgroups, num_tokens, 1,
|
||||
sizeof(args), &args,
|
||||
4,
|
||||
(const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer},
|
||||
(const size_t[]) {input_offset, weight_offset, bias_offset, output_offset});
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_bf16w_matmul_fn,
|
||||
size_t threadgroup_size,
|
||||
const struct gptoss_metal_buffer* input_buffer,
|
||||
size_t input_offset,
|
||||
const struct gptoss_metal_buffer* weight_buffer,
|
||||
size_t weight_offset,
|
||||
const struct gptoss_metal_buffer* bias_buffer,
|
||||
size_t bias_offset,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
uint32_t num_tokens,
|
||||
uint32_t num_cols,
|
||||
uint32_t num_rows)
|
||||
{
|
||||
if (command_buffer->object == NULL || f32_bf16w_matmul_fn->pipeline_state_object == NULL) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch: invalid command buffer or pipeline state object");
|
||||
return gptoss_status_invalid_state;
|
||||
}
|
||||
|
||||
if (threadgroup_size == 0) {
|
||||
threadgroup_size = f32_bf16w_matmul_fn->simdgroup_threads;
|
||||
} else if (threadgroup_size > f32_bf16w_matmul_fn->max_threadgroup_threads) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)",
|
||||
threadgroup_size, f32_bf16w_matmul_fn->max_threadgroup_threads);
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
if (num_cols % 4 != 0) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch: number of columns (%" PRIu32 ") is not divisible by 4",
|
||||
num_cols);
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
const size_t num_simdgroups = threadgroup_size / f32_bf16w_matmul_fn->simdgroup_threads;
|
||||
if (num_rows % num_simdgroups != 0) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch: number of rows (%" PRIu32 ") is not divisible by the number of simdgroups (%zu)",
|
||||
num_rows, num_simdgroups);
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
const struct gptoss_matmul_args args = {
|
||||
.num_column_vecs = num_cols / 4,
|
||||
.num_rows = num_rows,
|
||||
.add = 1,
|
||||
};
|
||||
|
||||
return gptoss_metal_command_buffer_encode_launch_kernel(
|
||||
command_buffer, f32_bf16w_matmul_fn,
|
||||
threadgroup_size, 1, 1,
|
||||
num_rows / num_simdgroups, num_tokens, 1,
|
||||
sizeof(args), &args,
|
||||
4,
|
||||
(const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer},
|
||||
(const size_t[]) {input_offset, weight_offset, bias_offset, output_offset});
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_bf16w_unembedding_fn,
|
||||
size_t threadgroup_size,
|
||||
size_t max_threadgroups,
|
||||
const struct gptoss_metal_buffer* input_buffer,
|
||||
size_t input_offset,
|
||||
const struct gptoss_metal_buffer* weight_buffer,
|
||||
size_t weight_offset,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
const struct gptoss_metal_buffer* argmax_buffer,
|
||||
size_t argmax_offset,
|
||||
uint32_t num_tokens,
|
||||
uint32_t num_cols,
|
||||
uint32_t num_rows)
|
||||
{
|
||||
if (command_buffer->object == NULL || f32_bf16w_unembedding_fn->pipeline_state_object == NULL) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_unembedding kernel launch: invalid command buffer or pipeline state object");
|
||||
return gptoss_status_invalid_state;
|
||||
}
|
||||
|
||||
if (threadgroup_size == 0) {
|
||||
threadgroup_size = f32_bf16w_unembedding_fn->simdgroup_threads;
|
||||
} else if (threadgroup_size > f32_bf16w_unembedding_fn->max_threadgroup_threads) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_unembedding kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)",
|
||||
threadgroup_size, f32_bf16w_unembedding_fn->max_threadgroup_threads);
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
if (num_cols % 4 != 0) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_unembedding kernel launch: number of columns (%" PRIu32 ") is not divisible by 4",
|
||||
num_cols);
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
const size_t num_simdgroups = threadgroup_size / f32_bf16w_unembedding_fn->simdgroup_threads;
|
||||
const size_t num_rows_per_threadgroup = math_ceil_div(num_rows, max_threadgroups * num_simdgroups) * num_simdgroups;
|
||||
const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_rows, num_rows_per_threadgroup));
|
||||
const struct gptoss_unembedding_args args = {
|
||||
.num_column_vecs = num_cols / 4,
|
||||
.num_rows_per_threadgroup = num_rows_per_threadgroup,
|
||||
.num_rows = num_rows,
|
||||
};
|
||||
|
||||
return gptoss_metal_command_buffer_encode_launch_kernel(
|
||||
command_buffer, f32_bf16w_unembedding_fn,
|
||||
threadgroup_size, 1, 1,
|
||||
num_threadgroups, num_tokens, 1,
|
||||
sizeof(args), &args,
|
||||
4,
|
||||
(const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer, argmax_buffer},
|
||||
(const size_t[]) {input_offset, weight_offset, output_offset, argmax_offset});
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_mf4w_moe_matmul_swiglu_fn,
|
||||
size_t threadgroup_size,
|
||||
const struct gptoss_metal_buffer* input_buffer,
|
||||
size_t input_offset,
|
||||
const struct gptoss_metal_buffer* expert_buffer,
|
||||
size_t expert_offset,
|
||||
const struct gptoss_metal_buffer* weight_block_buffer,
|
||||
size_t weight_block_offset,
|
||||
const struct gptoss_metal_buffer* weight_scale_buffer,
|
||||
size_t weight_scale_offset,
|
||||
const struct gptoss_metal_buffer* bias_buffer,
|
||||
size_t bias_offset,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
float swiglu_limit,
|
||||
uint32_t expert_stride,
|
||||
uint32_t num_tokens,
|
||||
uint32_t num_active_experts,
|
||||
uint32_t num_cols,
|
||||
uint32_t num_rows)
|
||||
{
|
||||
if (command_buffer->object == NULL || f32_mf4w_moe_matmul_swiglu_fn->pipeline_state_object == NULL) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch: invalid command buffer or pipeline state object");
|
||||
return gptoss_status_invalid_state;
|
||||
}
|
||||
|
||||
if (threadgroup_size == 0) {
|
||||
threadgroup_size = 2 * f32_mf4w_moe_matmul_swiglu_fn->simdgroup_threads;
|
||||
} else if (threadgroup_size > f32_mf4w_moe_matmul_swiglu_fn->max_threadgroup_threads) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)",
|
||||
threadgroup_size, f32_mf4w_moe_matmul_swiglu_fn->max_threadgroup_threads);
|
||||
return gptoss_status_invalid_argument;
|
||||
} else if (threadgroup_size % (2 * f32_mf4w_moe_matmul_swiglu_fn->simdgroup_threads)) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch: threadgroup size (%zu) is not divisible by simdgroup size (%zu) multiplied by 2X",
|
||||
threadgroup_size, f32_mf4w_moe_matmul_swiglu_fn->simdgroup_threads);
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
if (num_cols % 32 != 0) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch: number of columns (%" PRIu32 ") is not divisible by 32",
|
||||
num_cols);
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
const size_t num_simdgroups = threadgroup_size / f32_mf4w_moe_matmul_swiglu_fn->simdgroup_threads;
|
||||
if ((2 * num_rows) % num_simdgroups != 0) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch: "
|
||||
"the number of rows (%" PRIu32 ") multiplied by 2X is not divisible by the number of simdgroups (%zu)",
|
||||
num_rows, num_simdgroups);
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
const struct gptoss_moe_matmul_swiglu_args args = {
|
||||
.num_column_vecs = num_cols / 32,
|
||||
.num_rows = num_rows,
|
||||
.num_active_experts = num_active_experts,
|
||||
.weight_expert_stride = expert_stride,
|
||||
.output_expert_stride = num_rows * num_tokens,
|
||||
.swiglu_min = -swiglu_limit,
|
||||
.swiglu_max = swiglu_limit,
|
||||
};
|
||||
|
||||
return gptoss_metal_command_buffer_encode_launch_kernel(
|
||||
command_buffer, f32_mf4w_moe_matmul_swiglu_fn,
|
||||
threadgroup_size, 1, 1,
|
||||
(2 * num_rows) / num_simdgroups, num_tokens, num_active_experts,
|
||||
sizeof(args), &args,
|
||||
6,
|
||||
(const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer},
|
||||
(const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset});
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_mf4w_moe_matmul_fn,
|
||||
size_t threadgroup_size,
|
||||
const struct gptoss_metal_buffer* input_buffer,
|
||||
size_t input_offset,
|
||||
const struct gptoss_metal_buffer* expert_buffer,
|
||||
size_t expert_offset,
|
||||
const struct gptoss_metal_buffer* weight_block_buffer,
|
||||
size_t weight_block_offset,
|
||||
const struct gptoss_metal_buffer* weight_scale_buffer,
|
||||
size_t weight_scale_offset,
|
||||
const struct gptoss_metal_buffer* bias_buffer,
|
||||
size_t bias_offset,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
uint32_t expert_stride,
|
||||
uint32_t num_tokens,
|
||||
uint32_t num_active_experts,
|
||||
uint32_t num_cols,
|
||||
uint32_t num_rows)
|
||||
{
|
||||
if (command_buffer->object == NULL || f32_mf4w_moe_matmul_fn->pipeline_state_object == NULL) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch: invalid command buffer or pipeline state object");
|
||||
return gptoss_status_invalid_state;
|
||||
}
|
||||
|
||||
if (threadgroup_size == 0) {
|
||||
threadgroup_size = f32_mf4w_moe_matmul_fn->simdgroup_threads;
|
||||
} else if (threadgroup_size > f32_mf4w_moe_matmul_fn->max_threadgroup_threads) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)",
|
||||
threadgroup_size, f32_mf4w_moe_matmul_fn->max_threadgroup_threads);
|
||||
return gptoss_status_invalid_argument;
|
||||
} else if (threadgroup_size % f32_mf4w_moe_matmul_fn->simdgroup_threads) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch: threadgroup size (%zu) is not divisible by simdgroup size (%zu)",
|
||||
threadgroup_size, f32_mf4w_moe_matmul_fn->simdgroup_threads);
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
if (num_cols % 32 != 0) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch: number of columns (%" PRIu32 ") is not divisible by 32",
|
||||
num_cols);
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
const size_t num_simdgroups = threadgroup_size / f32_mf4w_moe_matmul_fn->simdgroup_threads;
|
||||
if (num_rows % num_simdgroups != 0) {
|
||||
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch: "
|
||||
"the number of rows (%" PRIu32 ") is not divisible by the number of simdgroups (%zu)",
|
||||
num_rows, num_simdgroups);
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
const struct gptoss_moe_matmul_args args = {
|
||||
.num_column_vecs = num_cols / 32,
|
||||
.num_rows = num_rows,
|
||||
.num_active_experts = num_active_experts,
|
||||
.input_expert_stride = num_tokens * (num_cols / 32),
|
||||
.weight_expert_stride = expert_stride,
|
||||
.output_expert_stride = num_rows * num_tokens,
|
||||
};
|
||||
|
||||
return gptoss_metal_command_buffer_encode_launch_kernel(
|
||||
command_buffer, f32_mf4w_moe_matmul_fn,
|
||||
threadgroup_size, 1, 1,
|
||||
num_rows / num_simdgroups, num_tokens, num_active_experts,
|
||||
sizeof(args), &args,
|
||||
6,
|
||||
(const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer},
|
||||
(const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset});
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_rope_fn,
|
||||
size_t threadgroup_size,
|
||||
const struct gptoss_metal_buffer* activations_buffer,
|
||||
float rope_base,
|
||||
float interpolation_scale,
|
||||
float yarn_offset,
|
||||
float yarn_scale,
|
||||
float yarn_multiplier,
|
||||
uint32_t num_tokens,
|
||||
uint32_t num_q_heads,
|
||||
uint32_t num_kv_heads,
|
||||
uint32_t attn_head_dim,
|
||||
uint32_t token_offset)
|
||||
{
|
||||
if (command_buffer->object == NULL || f32_rope_fn->pipeline_state_object == NULL) {
|
||||
return gptoss_status_invalid_state;
|
||||
}
|
||||
|
||||
if (threadgroup_size == 0) {
|
||||
threadgroup_size = f32_rope_fn->max_threadgroup_threads;
|
||||
} else if (threadgroup_size > f32_rope_fn->max_threadgroup_threads) {
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
const size_t num_simdgroups = threadgroup_size / f32_rope_fn->simdgroup_threads;
|
||||
const uint32_t num_qk_heads = num_q_heads + num_kv_heads;
|
||||
if (num_qk_heads % num_simdgroups != 0) {
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
const struct gptoss_rope_args args = {
|
||||
.token_stride = (num_q_heads + 2 * num_kv_heads) * (attn_head_dim / 2),
|
||||
.token_offset = token_offset,
|
||||
.freq_scale = -logf(rope_base) / (float) (int32_t) (attn_head_dim / 2),
|
||||
.interpolation_scale = interpolation_scale,
|
||||
.yarn_offset = yarn_offset,
|
||||
.yarn_scale = yarn_scale,
|
||||
.yarn_multiplier = yarn_multiplier,
|
||||
};
|
||||
|
||||
return gptoss_metal_command_buffer_encode_launch_kernel(
|
||||
command_buffer, f32_rope_fn,
|
||||
threadgroup_size, 1, 1,
|
||||
num_qk_heads / num_simdgroups, num_tokens, 1,
|
||||
sizeof(args), &args,
|
||||
1, (const struct gptoss_metal_buffer *[]) {activations_buffer}, NULL);
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_accumulate_fn,
|
||||
size_t threadgroup_size,
|
||||
size_t max_threadgroups,
|
||||
const struct gptoss_metal_buffer* input_buffer,
|
||||
size_t input_offset,
|
||||
const struct gptoss_metal_buffer* expert_buffer,
|
||||
size_t expert_offset,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
uint32_t num_channels,
|
||||
uint32_t num_tokens,
|
||||
uint32_t num_experts)
|
||||
{
|
||||
if (command_buffer->object == NULL || f32_accumulate_fn->pipeline_state_object == NULL) {
|
||||
return gptoss_status_invalid_state;
|
||||
}
|
||||
|
||||
if (num_channels% 4 != 0) {
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
if (threadgroup_size == 0) {
|
||||
threadgroup_size = f32_accumulate_fn->max_threadgroup_threads;
|
||||
} else if (threadgroup_size > f32_accumulate_fn->max_threadgroup_threads) {
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
const size_t num_vecs = num_channels / 4;
|
||||
const size_t num_vecs_per_expert = num_vecs * num_tokens;
|
||||
const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;
|
||||
const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));
|
||||
const struct gptoss_accumulate_args args = {
|
||||
.num_vecs_per_expert = num_vecs_per_expert,
|
||||
.num_vecs_per_threadgroup = num_vecs_per_threadgroup,
|
||||
.num_vecs = num_vecs,
|
||||
};
|
||||
|
||||
return gptoss_metal_command_buffer_encode_launch_kernel(
|
||||
command_buffer, f32_accumulate_fn,
|
||||
threadgroup_size, 1, 1,
|
||||
num_threadgroups, num_tokens, 1,
|
||||
sizeof(args), &args,
|
||||
3,
|
||||
(const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, output_buffer},
|
||||
(const size_t[]) {input_offset, expert_offset, output_offset});
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_topk_fn,
|
||||
const struct gptoss_metal_buffer* input_buffer,
|
||||
size_t input_offset,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
uint32_t num_tokens,
|
||||
uint32_t num_experts,
|
||||
uint32_t num_active_experts)
|
||||
{
|
||||
if (command_buffer->object == NULL || f32_topk_fn->pipeline_state_object == NULL) {
|
||||
return gptoss_status_invalid_state;
|
||||
}
|
||||
|
||||
if (num_experts != 32 && num_experts != 128) {
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
if (num_active_experts != 4) {
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
const struct gptoss_topk_args args = { 0 };
|
||||
|
||||
return gptoss_metal_command_buffer_encode_launch_kernel(
|
||||
command_buffer, f32_topk_fn,
|
||||
/*threadgroup_size=*/32, 1, 1,
|
||||
num_tokens, 1, 1,
|
||||
sizeof(args), &args,
|
||||
2,
|
||||
(const struct gptoss_metal_buffer *[]) {input_buffer, output_buffer},
|
||||
(const size_t[]) {input_offset, output_offset});
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_sdpa_fn,
|
||||
const struct gptoss_metal_buffer* q_buffer,
|
||||
size_t q_offset,
|
||||
const struct gptoss_metal_buffer* k_buffer,
|
||||
size_t k_offset,
|
||||
const struct gptoss_metal_buffer* v_buffer,
|
||||
size_t v_offset,
|
||||
const struct gptoss_metal_buffer* s_buffer,
|
||||
size_t s_offset,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
uint32_t window,
|
||||
uint32_t num_q_tokens,
|
||||
uint32_t num_kv_tokens,
|
||||
uint32_t num_q_heads,
|
||||
uint32_t num_kv_heads,
|
||||
uint32_t head_dim)
|
||||
{
|
||||
if (command_buffer->object == NULL || f32_sdpa_fn->pipeline_state_object == NULL) {
|
||||
return gptoss_status_invalid_state;
|
||||
}
|
||||
|
||||
if (num_q_heads != num_kv_heads * 8) {
|
||||
GPTOSS_LOG_ERROR("number of Q heads (%" PRIu32 ") must be 8 times the number of KV heads (%" PRIu32 ")",
|
||||
num_q_heads, num_kv_heads);
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
if (head_dim != 64) {
|
||||
GPTOSS_LOG_ERROR("attention head dimension (%" PRIu32 ") must be 64", head_dim);
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
const struct gptoss_sdpa_args args = {
|
||||
.qkv_dim = head_dim * (num_q_heads + 2 * num_kv_heads),
|
||||
.num_kv_tokens = num_kv_tokens,
|
||||
.window = window,
|
||||
};
|
||||
|
||||
return gptoss_metal_command_buffer_encode_launch_kernel(
|
||||
command_buffer, f32_sdpa_fn,
|
||||
/*threadgroup_size=*/32, 1, 1,
|
||||
num_q_tokens, num_kv_heads, 1,
|
||||
sizeof(args), &args,
|
||||
5,
|
||||
(const struct gptoss_metal_buffer *[]) {q_buffer, k_buffer, v_buffer, s_buffer, output_buffer},
|
||||
(const size_t[]) {q_offset, k_offset, v_offset, s_offset, output_offset});
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* f32_softmax_fn,
|
||||
size_t threadgroup_size,
|
||||
size_t max_threadgroups,
|
||||
const struct gptoss_metal_buffer* score_buffer,
|
||||
size_t score_offset,
|
||||
const struct gptoss_metal_buffer* argmax_buffer,
|
||||
size_t argmax_offset,
|
||||
const struct gptoss_metal_buffer* prob_buffer,
|
||||
size_t prob_offset,
|
||||
const struct gptoss_metal_buffer* sum_buffer,
|
||||
size_t sum_offset,
|
||||
uint32_t num_channels,
|
||||
uint32_t num_tokens,
|
||||
float temperature,
|
||||
uint32_t* num_threadgroups_out,
|
||||
uint32_t* num_channels_per_threadgroup_out)
|
||||
{
|
||||
*num_threadgroups_out = 0;
|
||||
*num_channels_per_threadgroup_out = 0;
|
||||
if (command_buffer->object == NULL || f32_softmax_fn->pipeline_state_object == NULL) {
|
||||
return gptoss_status_invalid_state;
|
||||
}
|
||||
|
||||
const size_t num_vecs = num_channels;
|
||||
const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;
|
||||
const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));
|
||||
const struct gptoss_softmax_args args = {
|
||||
.num_vecs = num_vecs,
|
||||
.num_vecs_per_threadgroup = num_vecs_per_threadgroup,
|
||||
.max_threadgroups = max_threadgroups,
|
||||
.temperature = temperature,
|
||||
};
|
||||
|
||||
*num_threadgroups_out = num_threadgroups;
|
||||
*num_channels_per_threadgroup_out = num_vecs_per_threadgroup;
|
||||
return gptoss_metal_command_buffer_encode_launch_kernel(
|
||||
command_buffer, f32_softmax_fn,
|
||||
threadgroup_size, 1, 1,
|
||||
num_threadgroups, num_tokens, 1,
|
||||
sizeof(args), &args,
|
||||
4,
|
||||
(const struct gptoss_metal_buffer *[]) {score_buffer, argmax_buffer, prob_buffer, sum_buffer},
|
||||
(const size_t[]) {score_offset, argmax_offset, prob_offset, sum_offset});
|
||||
}
|
||||
453
gpt_oss/metal/source/metal.m
Normal file
453
gpt_oss/metal/source/metal.m
Normal file
@@ -0,0 +1,453 @@
|
||||
#import <Foundation/Foundation.h>
|
||||
#import <Metal/Metal.h>
|
||||
|
||||
#include <dispatch/dispatch.h>
|
||||
#include <mach-o/getsect.h>
|
||||
|
||||
#include <gpt-oss/types.h>
|
||||
|
||||
#include <internal/log.h>
|
||||
#include <internal/metal.h>
|
||||
|
||||
|
||||
static size_t gptoss_metal_device_get_core_count(id<MTLDevice> device) {
|
||||
if (!device) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint64_t target_registry_id = [device registryID];
|
||||
|
||||
io_iterator_t it = IO_OBJECT_NULL;
|
||||
const kern_return_t kr = IOServiceGetMatchingServices(
|
||||
kIOMainPortDefault,
|
||||
IOServiceMatching("IOAccelerator"),
|
||||
&it
|
||||
);
|
||||
if (kr != KERN_SUCCESS) {
|
||||
GPTOSS_LOG_ERROR("failed to find IOAccelerator objects: error %d", kr);
|
||||
return 0;
|
||||
}
|
||||
|
||||
size_t result = 0;
|
||||
for (io_object_t obj = IOIteratorNext(it); obj != IO_OBJECT_NULL; obj = IOIteratorNext(it)) {
|
||||
uint64_t registry_id = 0;
|
||||
if (IORegistryEntryGetRegistryEntryID(obj, ®istry_id) == KERN_SUCCESS &&
|
||||
registry_id == target_registry_id)
|
||||
{
|
||||
// Read "gpu-core-count" from this accelerator node
|
||||
const CFTypeRef value = IORegistryEntryCreateCFProperty(
|
||||
obj, CFSTR("gpu-core-count"), kCFAllocatorDefault, 0);
|
||||
if (value != NULL) {
|
||||
if (CFGetTypeID(value) == CFNumberGetTypeID()) {
|
||||
int32_t n = -1;
|
||||
if (CFNumberGetValue((CFNumberRef) value, kCFNumberSInt32Type, &n) && n > 0) {
|
||||
result = (size_t) n;
|
||||
}
|
||||
}
|
||||
CFRelease(value);
|
||||
}
|
||||
IOObjectRelease(obj);
|
||||
break;
|
||||
}
|
||||
IOObjectRelease(obj);
|
||||
}
|
||||
|
||||
IOObjectRelease(it);
|
||||
return result;
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_device_create_system_default(
|
||||
struct gptoss_metal_device* device_out)
|
||||
{
|
||||
id<MTLDevice> device_obj = MTLCreateSystemDefaultDevice();
|
||||
if (device_obj == nil) {
|
||||
GPTOSS_LOG_ERROR("failed to create Metal device");
|
||||
return gptoss_status_unsupported_system;
|
||||
}
|
||||
|
||||
device_out->object = (void*) device_obj;
|
||||
device_out->num_cores = gptoss_metal_device_get_core_count(device_obj);
|
||||
device_out->max_buffer_size = (size_t) [device_obj maxBufferLength];
|
||||
device_out->max_threadgroup_memory = (size_t) [device_obj maxThreadgroupMemoryLength];
|
||||
const MTLSize max_threadgroup_threads = [device_obj maxThreadsPerThreadgroup];
|
||||
device_out->max_threadgroup_threads_x = (size_t) max_threadgroup_threads.width;
|
||||
device_out->max_threadgroup_threads_y = (size_t) max_threadgroup_threads.height;
|
||||
device_out->max_threadgroup_threads_z = (size_t) max_threadgroup_threads.depth;
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_device_release(
|
||||
struct gptoss_metal_device* device)
|
||||
{
|
||||
if (device->object != NULL) {
|
||||
id<MTLDevice> device_obj = (id<MTLDevice>) device->object;
|
||||
[device_obj release];
|
||||
}
|
||||
memset(device, 0, sizeof(struct gptoss_metal_device));
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
extern const struct mach_header_64 __dso_handle;
|
||||
|
||||
enum gptoss_status gptoss_metal_library_create_default(
|
||||
const struct gptoss_metal_device* device,
|
||||
struct gptoss_metal_library* library_out)
|
||||
{
|
||||
enum gptoss_status status = gptoss_status_success;
|
||||
id<MTLDevice> device_obj = (id<MTLDevice>) device->object;
|
||||
id<MTLLibrary> library_obj = nil;
|
||||
NSError* error_obj = nil;
|
||||
NSString* error_string_obj = nil;
|
||||
dispatch_data_t library_blob = NULL;
|
||||
|
||||
unsigned long library_size = 0;
|
||||
uint8_t* library_data = getsectiondata(&__dso_handle, "__METAL", "__shaders", &library_size);
|
||||
if (library_data != NULL) {
|
||||
library_blob = dispatch_data_create(library_data, library_size, NULL, DISPATCH_DATA_DESTRUCTOR_DEFAULT);
|
||||
library_obj = [device_obj newLibraryWithData:library_blob error:&error_obj];
|
||||
if (library_obj == nil) {
|
||||
error_string_obj = [error_obj localizedDescription];
|
||||
GPTOSS_LOG_ERROR("failed to create Metal library: %s", [error_string_obj UTF8String]);
|
||||
status = gptoss_status_unsupported_system;
|
||||
goto cleanup;
|
||||
}
|
||||
} else {
|
||||
// Fall-back to loading from the bundle
|
||||
library_obj = [device_obj newDefaultLibrary];
|
||||
if (library_obj == nil) {
|
||||
GPTOSS_LOG_ERROR("failed to create Metal default library");
|
||||
status = gptoss_status_unsupported_system;
|
||||
goto cleanup;
|
||||
}
|
||||
}
|
||||
|
||||
*library_out = (struct gptoss_metal_library) {
|
||||
.object = (void*) library_obj,
|
||||
};
|
||||
|
||||
cleanup:
|
||||
if (library_blob != NULL) {
|
||||
dispatch_release(library_blob);
|
||||
}
|
||||
if (error_string_obj != nil) {
|
||||
[error_string_obj release];
|
||||
}
|
||||
if (error_obj != nil) {
|
||||
[error_obj release];
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_library_release(
|
||||
struct gptoss_metal_library* library)
|
||||
{
|
||||
if (library->object != NULL) {
|
||||
id<MTLLibrary> library_obj = (id<MTLLibrary>) library->object;
|
||||
[library_obj release];
|
||||
}
|
||||
memset(library, 0, sizeof(struct gptoss_metal_library));
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_function_create(
|
||||
const struct gptoss_metal_library* library,
|
||||
const char* name,
|
||||
struct gptoss_metal_function* function_out)
|
||||
{
|
||||
NSString* name_obj = nil;
|
||||
NSError* error_obj = nil;
|
||||
NSString* error_string_obj = nil;
|
||||
id<MTLFunction> function_obj = nil;
|
||||
enum gptoss_status status = gptoss_status_success;
|
||||
|
||||
id<MTLLibrary> library_obj = (id<MTLLibrary>) library->object;
|
||||
name_obj = [NSString stringWithUTF8String:name];
|
||||
function_obj = [library_obj newFunctionWithName:name_obj];
|
||||
if (function_obj == nil) {
|
||||
GPTOSS_LOG_ERROR("failed to create Metal function %s", name);
|
||||
status = gptoss_status_unsupported_system;
|
||||
goto cleanup;
|
||||
}
|
||||
id<MTLDevice> device_obj = [library_obj device];
|
||||
id<MTLComputePipelineState> pipeline_state_obj = [device_obj newComputePipelineStateWithFunction:function_obj error:&error_obj];
|
||||
if (pipeline_state_obj == nil) {
|
||||
error_string_obj = [error_obj localizedDescription];
|
||||
GPTOSS_LOG_ERROR("failed to create Metal compute pipeline state for function %s: %s",
|
||||
name, [error_string_obj UTF8String]);
|
||||
status = gptoss_status_unsupported_system;
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
// Commit
|
||||
function_out->function_object = function_obj;
|
||||
function_out->pipeline_state_object = pipeline_state_obj;
|
||||
function_out->max_threadgroup_threads = (size_t) [pipeline_state_obj maxTotalThreadsPerThreadgroup];
|
||||
function_out->simdgroup_threads = (size_t) [pipeline_state_obj threadExecutionWidth];
|
||||
function_out->static_threadgroup_memory = (size_t) [pipeline_state_obj staticThreadgroupMemoryLength];
|
||||
|
||||
function_obj = nil;
|
||||
pipeline_state_obj = nil;
|
||||
|
||||
cleanup:
|
||||
if (name_obj != nil) {
|
||||
[name_obj release];
|
||||
}
|
||||
if (function_obj != nil) {
|
||||
[function_obj release];
|
||||
}
|
||||
if (error_string_obj != nil) {
|
||||
[error_string_obj release];
|
||||
}
|
||||
if (error_obj != nil) {
|
||||
[error_obj release];
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_function_release(
|
||||
struct gptoss_metal_function* function)
|
||||
{
|
||||
if (function->pipeline_state_object != NULL) {
|
||||
id<MTLComputePipelineState> pipeline_state_obj = (id<MTLComputePipelineState>) function->pipeline_state_object;
|
||||
[pipeline_state_obj release];
|
||||
}
|
||||
if (function->function_object != NULL) {
|
||||
id<MTLFunction> function_obj = (id<MTLFunction>) function->function_object;
|
||||
[function_obj release];
|
||||
}
|
||||
memset(function, 0, sizeof(struct gptoss_metal_function));
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_buffer_create(
|
||||
const struct gptoss_metal_device* device,
|
||||
size_t size,
|
||||
const void* data,
|
||||
struct gptoss_metal_buffer* buffer_out)
|
||||
{
|
||||
id<MTLDevice> device_obj = (id<MTLDevice>) device->object;
|
||||
id<MTLBuffer> buffer_obj = nil;
|
||||
if (data != NULL) {
|
||||
buffer_obj = [device_obj newBufferWithBytes:data length:size options:MTLResourceStorageModeShared];
|
||||
} else {
|
||||
buffer_obj = [device_obj newBufferWithLength:size options:MTLResourceStorageModeShared];
|
||||
}
|
||||
if (buffer_obj == nil) {
|
||||
GPTOSS_LOG_ERROR("failed to create Metal buffer of size %zu", size);
|
||||
return gptoss_status_unsupported_system;
|
||||
}
|
||||
buffer_out->object = (void*) buffer_obj;
|
||||
buffer_out->size = size;
|
||||
buffer_out->ptr = [buffer_obj contents];
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_buffer_wrap(
|
||||
const struct gptoss_metal_device* device,
|
||||
size_t size,
|
||||
const void* data,
|
||||
struct gptoss_metal_buffer* buffer_out)
|
||||
{
|
||||
id<MTLDevice> device_obj = (id<MTLDevice>) device->object;
|
||||
id<MTLBuffer> buffer_obj = [device_obj newBufferWithBytesNoCopy:(void*) data length:size options:MTLResourceStorageModeShared deallocator:nil];
|
||||
if (buffer_obj == nil) {
|
||||
GPTOSS_LOG_ERROR("failed to wrap Metal buffer of size %zu", size);
|
||||
return gptoss_status_unsupported_system;
|
||||
}
|
||||
buffer_out->object = (void*) buffer_obj;
|
||||
buffer_out->size = size;
|
||||
buffer_out->ptr = (void*) data;
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_buffer_release(
|
||||
struct gptoss_metal_buffer* buffer)
|
||||
{
|
||||
if (buffer->object != NULL) {
|
||||
id<MTLBuffer> buffer_obj = (id<MTLBuffer>) buffer->object;
|
||||
[buffer_obj release];
|
||||
}
|
||||
memset(buffer, 0, sizeof(struct gptoss_metal_buffer));
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_queue_create(
|
||||
const struct gptoss_metal_device* device,
|
||||
struct gptoss_metal_command_queue* command_queue_out)
|
||||
{
|
||||
id<MTLDevice> device_obj = (id<MTLDevice>) device->object;
|
||||
id<MTLCommandQueue> command_queue_obj = [device_obj newCommandQueue];
|
||||
if (command_queue_obj == nil) {
|
||||
GPTOSS_LOG_ERROR("failed to create Metal command queue");
|
||||
return gptoss_status_unsupported_system;
|
||||
}
|
||||
command_queue_out->object = (void*) command_queue_obj;
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_queue_release(
|
||||
struct gptoss_metal_command_queue* command_queue)
|
||||
{
|
||||
if (command_queue->object != NULL) {
|
||||
id<MTLCommandQueue> command_queue_obj = (id<MTLCommandQueue>) command_queue->object;
|
||||
[command_queue_obj release];
|
||||
}
|
||||
memset(command_queue, 0, sizeof(struct gptoss_metal_command_queue));
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_create(
|
||||
const struct gptoss_metal_command_queue* command_queue,
|
||||
struct gptoss_metal_command_buffer* command_buffer_out)
|
||||
{
|
||||
id<MTLCommandQueue> command_queue_obj = (id<MTLCommandQueue>) command_queue->object;
|
||||
id<MTLCommandBuffer> command_buffer_obj = [command_queue_obj commandBuffer];
|
||||
if (command_buffer_obj == nil) {
|
||||
GPTOSS_LOG_ERROR("failed to create Metal command buffer");
|
||||
return gptoss_status_unsupported_system;
|
||||
}
|
||||
[command_buffer_obj retain];
|
||||
command_buffer_out->object = (void*) command_buffer_obj;
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_fill_buffer(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_buffer* buffer,
|
||||
size_t offset,
|
||||
size_t size,
|
||||
uint8_t fill_value)
|
||||
{
|
||||
if (command_buffer->object == NULL) {
|
||||
return gptoss_status_invalid_state;
|
||||
}
|
||||
if (buffer->object == NULL) {
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
id<MTLCommandBuffer> command_buffer_obj = (id<MTLCommandBuffer>) command_buffer->object;
|
||||
id<MTLBuffer> buffer_obj = (id<MTLBuffer>) buffer->object;
|
||||
|
||||
id<MTLBlitCommandEncoder> command_encoder_obj = [command_buffer_obj blitCommandEncoder];
|
||||
|
||||
const NSRange range = NSMakeRange((NSUInteger) offset, (NSUInteger) size);
|
||||
[command_encoder_obj fillBuffer:buffer_obj range:range value:fill_value];
|
||||
[command_encoder_obj endEncoding];
|
||||
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_copy_buffer(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_buffer* input_buffer,
|
||||
size_t input_offset,
|
||||
const struct gptoss_metal_buffer* output_buffer,
|
||||
size_t output_offset,
|
||||
size_t size)
|
||||
{
|
||||
if (command_buffer->object == NULL) {
|
||||
return gptoss_status_invalid_state;
|
||||
}
|
||||
if (input_buffer->object == NULL) {
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
if (output_buffer->object == NULL) {
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
id<MTLCommandBuffer> command_buffer_obj = (id<MTLCommandBuffer>) command_buffer->object;
|
||||
id<MTLBuffer> input_buffer_obj = (id<MTLBuffer>) input_buffer->object;
|
||||
id<MTLBuffer> output_buffer_obj = (id<MTLBuffer>) output_buffer->object;
|
||||
|
||||
id<MTLBlitCommandEncoder> command_encoder_obj = [command_buffer_obj blitCommandEncoder];
|
||||
|
||||
[command_encoder_obj copyFromBuffer:input_buffer_obj sourceOffset:(NSUInteger) input_offset
|
||||
toBuffer:output_buffer_obj destinationOffset:(NSUInteger) output_offset
|
||||
size:(NSUInteger) size];
|
||||
[command_encoder_obj endEncoding];
|
||||
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_encode_launch_kernel(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
const struct gptoss_metal_function* function,
|
||||
size_t threadgroup_size_x,
|
||||
size_t threadgroup_size_y,
|
||||
size_t threadgroup_size_z,
|
||||
size_t num_threadgroups_x,
|
||||
size_t num_threadgroups_y,
|
||||
size_t num_threadgroups_z,
|
||||
size_t params_size,
|
||||
const void* params,
|
||||
size_t num_buffers,
|
||||
const struct gptoss_metal_buffer** buffers,
|
||||
const size_t* buffer_offsets)
|
||||
{
|
||||
if (command_buffer->object == NULL || function->pipeline_state_object == NULL) {
|
||||
return gptoss_status_invalid_state;
|
||||
}
|
||||
|
||||
id<MTLCommandBuffer> command_buffer_obj = (id<MTLCommandBuffer>) command_buffer->object;
|
||||
id<MTLComputePipelineState> pipeline_state_obj = (id<MTLComputePipelineState>) function->pipeline_state_object;
|
||||
|
||||
id<MTLComputeCommandEncoder> command_encoder_obj = [command_buffer_obj computeCommandEncoder];
|
||||
|
||||
// Set kernel arguments
|
||||
[command_encoder_obj setComputePipelineState:pipeline_state_obj];
|
||||
[command_encoder_obj setBytes:params length:params_size atIndex:0];
|
||||
for (size_t i = 0; i < num_buffers; ++i) {
|
||||
id<MTLBuffer> buffer_obj = (id<MTLBuffer>) buffers[i]->object;
|
||||
const NSUInteger offset = buffer_offsets == NULL ? 0 : (NSUInteger) buffer_offsets[i];
|
||||
[command_encoder_obj setBuffer:buffer_obj offset:offset atIndex:i + 1];
|
||||
}
|
||||
|
||||
// Dispatch kernel
|
||||
const MTLSize threadgroup_size = MTLSizeMake(threadgroup_size_x, threadgroup_size_y, threadgroup_size_z);
|
||||
const MTLSize num_threadgroups = MTLSizeMake(num_threadgroups_x, num_threadgroups_y, num_threadgroups_z);
|
||||
[command_encoder_obj dispatchThreadgroups:num_threadgroups threadsPerThreadgroup:threadgroup_size];
|
||||
[command_encoder_obj endEncoding];
|
||||
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_commit(
|
||||
const struct gptoss_metal_command_buffer* command_buffer)
|
||||
{
|
||||
if (command_buffer->object == NULL) {
|
||||
return gptoss_status_invalid_state;
|
||||
}
|
||||
|
||||
id<MTLCommandBuffer> command_buffer_obj = (id<MTLCommandBuffer>) command_buffer->object;
|
||||
[command_buffer_obj commit];
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_wait_completion(
|
||||
const struct gptoss_metal_command_buffer* command_buffer,
|
||||
double* elapsed_seconds)
|
||||
{
|
||||
if (command_buffer->object == NULL) {
|
||||
return gptoss_status_invalid_state;
|
||||
}
|
||||
|
||||
id<MTLCommandBuffer> command_buffer_obj = (id<MTLCommandBuffer>) command_buffer->object;
|
||||
[command_buffer_obj waitUntilCompleted];
|
||||
if (elapsed_seconds != NULL) {
|
||||
const CFTimeInterval start_time = [command_buffer_obj GPUStartTime];
|
||||
const CFTimeInterval end_time = [command_buffer_obj GPUEndTime];
|
||||
*elapsed_seconds = (double) end_time - (double) start_time;
|
||||
}
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status gptoss_metal_command_buffer_release(
|
||||
struct gptoss_metal_command_buffer* command_buffer)
|
||||
{
|
||||
if (command_buffer->object != NULL) {
|
||||
id<MTLCommandBuffer> command_buffer_obj = (id<MTLCommandBuffer>) command_buffer->object;
|
||||
[command_buffer_obj release];
|
||||
}
|
||||
memset(command_buffer, 0, sizeof(struct gptoss_metal_command_buffer));
|
||||
return gptoss_status_success;
|
||||
}
|
||||
560
gpt_oss/metal/source/model.c
Normal file
560
gpt_oss/metal/source/model.c
Normal file
@@ -0,0 +1,560 @@
|
||||
#include <assert.h>
|
||||
#include <inttypes.h>
|
||||
#include <stdatomic.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <errno.h> // errno, EISDIR, ENOENT, ENOTDIR
|
||||
#include <fcntl.h> // open
|
||||
#include <mach/vm_page_size.h> // vm_page_size
|
||||
#include <sys/mman.h> // mmap, PROT_READ, MAP_PRIVATE
|
||||
#include <sys/stat.h> // fstat, stat
|
||||
#include <sys/types.h> // off_t, ssize_t
|
||||
#include <unistd.h> // close
|
||||
|
||||
#include <gpt-oss.h>
|
||||
|
||||
#include "internal/datatype.h"
|
||||
#include "internal/kernel-args.h" // gptoss_expert_prediction
|
||||
#include "internal/log.h"
|
||||
#include "internal/uuid.h"
|
||||
#include "internal/storage.h"
|
||||
#include "internal/math.h"
|
||||
#include "internal/model.h"
|
||||
|
||||
|
||||
static size_t round_up_to_page_size(size_t bytes) {
|
||||
const size_t page_size_mask = (size_t) vm_page_size - 1;
|
||||
if ((bytes & page_size_mask) != 0) {
|
||||
bytes |= page_size_mask;
|
||||
bytes += 1;
|
||||
}
|
||||
return bytes;
|
||||
}
|
||||
|
||||
static size_t round_down_to_page_size(size_t bytes) {
|
||||
const size_t page_size_mask = (size_t) vm_page_size - 1;
|
||||
return bytes & ~page_size_mask;
|
||||
}
|
||||
|
||||
static enum gptoss_status read_fd(int fd, void* data, size_t size, const char* path) {
|
||||
assert(fd != -1);
|
||||
assert(data != NULL);
|
||||
assert(size != 0);
|
||||
|
||||
size_t bytes_to_read = size;
|
||||
char* current_byte = (char*) data;
|
||||
do {
|
||||
const ssize_t read_result = read(fd, current_byte, bytes_to_read);
|
||||
if (read_result < 0) {
|
||||
GPTOSS_LOG_ERROR("reading %zu bytes from file %s failed with error %d",
|
||||
size, path, errno);
|
||||
return gptoss_status_io_error;
|
||||
}
|
||||
current_byte += (size_t) read_result;
|
||||
bytes_to_read -= (size_t) read_result;
|
||||
} while (bytes_to_read != 0);
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
static void prefetch_fd(int fd, size_t offset, size_t size, const char* path) {
|
||||
// radvisory.ra_count is int, so we can't prefetch 2GB+ at once
|
||||
const size_t prefetch_max = round_down_to_page_size((size_t) INT_MAX);
|
||||
do {
|
||||
const size_t prefetch_size = math_min(size, prefetch_max);
|
||||
const struct radvisory ra = {
|
||||
.ra_offset = offset,
|
||||
.ra_count = (int) prefetch_size,
|
||||
};
|
||||
if (fcntl(fd, F_RDADVISE, &ra) == -1) {
|
||||
GPTOSS_LOG_WARNING("fcntl(%s, F_RDADVISE, .ra_offset=%zu, .ra_count=%d) failed with error %d\n",
|
||||
path, (size_t) ra.ra_offset, ra.ra_count, errno);
|
||||
return;
|
||||
}
|
||||
offset += prefetch_size;
|
||||
size -= prefetch_size;
|
||||
} while (size != 0);
|
||||
}
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file(
|
||||
const char* path,
|
||||
gptoss_model_t* model_out)
|
||||
{
|
||||
*model_out = NULL;
|
||||
|
||||
enum gptoss_status status = gptoss_status_success;
|
||||
struct gptoss_model* model = NULL;
|
||||
struct gptoss_tokenizer* tokenizer = NULL;
|
||||
int fd = -1;
|
||||
size_t file_offset = 0;
|
||||
|
||||
fd = open(path, O_RDONLY);
|
||||
if (fd == -1) {
|
||||
GPTOSS_LOG_ERROR("open(%s) failed with error %d", path, errno);
|
||||
switch (errno) {
|
||||
case EISDIR:
|
||||
case ENOENT:
|
||||
case ENOTDIR:
|
||||
status = gptoss_status_invalid_argument;
|
||||
break;
|
||||
default:
|
||||
status = gptoss_status_io_error;
|
||||
break;
|
||||
}
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
struct gptoss_file_header file_header;
|
||||
status = read_fd(fd, &file_header, sizeof(file_header), path);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
file_offset += sizeof(file_header);
|
||||
|
||||
if (file_header.magic[0] != 'G' ||
|
||||
file_header.magic[1] != 'P' ||
|
||||
file_header.magic[2] != 'T' ||
|
||||
file_header.magic[3] != '-' ||
|
||||
file_header.magic[4] != 'O' ||
|
||||
file_header.magic[5] != 'S' ||
|
||||
file_header.magic[6] != 'S' ||
|
||||
file_header.magic[7] != ' ' ||
|
||||
file_header.magic[8] != 'v' ||
|
||||
file_header.magic[9] != '1' ||
|
||||
file_header.magic[10] != '.' ||
|
||||
file_header.magic[11] != '0' ||
|
||||
file_header.zero != 0)
|
||||
{
|
||||
GPTOSS_LOG_ERROR("invalid magic in file %s", path);
|
||||
status = gptoss_status_invalid_argument;
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
struct gptoss_uuid model_uuid;
|
||||
status = read_fd(fd, &model_uuid, sizeof(model_uuid), path);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
file_offset += sizeof(model_uuid);
|
||||
|
||||
if (!gptoss_is_gptoss_model_uuid(&model_uuid)) {
|
||||
GPTOSS_LOG_ERROR("unsupported model UUID " UUID_FORMAT, UUID_ARGS(model_uuid));
|
||||
status = gptoss_status_invalid_argument;
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
struct gptoss_gptoss_model_header model_header;
|
||||
status = read_fd(fd, &model_header, sizeof(model_header), path);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
file_offset += sizeof(model_header);
|
||||
|
||||
struct gptoss_uuid layout_uuid;
|
||||
status = read_fd(fd, &layout_uuid, sizeof(layout_uuid), path);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
file_offset += sizeof(layout_uuid);
|
||||
|
||||
if (!gptoss_is_applegpu_layout_uuid(&layout_uuid)) {
|
||||
GPTOSS_LOG_ERROR("unsupported layout UUID " UUID_FORMAT, UUID_ARGS(layout_uuid));
|
||||
status = gptoss_status_invalid_argument;
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
const size_t model_size = sizeof(struct gptoss_model) + model_header.num_blocks * sizeof(struct gptoss_metal_buffer);
|
||||
model = malloc(model_size);
|
||||
if (model == NULL) {
|
||||
GPTOSS_LOG_ERROR("failed to allocate %zu bytes for model descriptor", model_size);
|
||||
status = gptoss_status_insufficient_memory;
|
||||
goto cleanup;
|
||||
}
|
||||
memset(model, 0, model_size);
|
||||
|
||||
atomic_store_explicit(&model->ref_count, 1, memory_order_relaxed);
|
||||
model->context_length = model_header.context_length;
|
||||
model->num_blocks = model_header.num_blocks;
|
||||
model->num_experts = model_header.num_experts;
|
||||
model->num_active_experts = model_header.num_active_experts;
|
||||
model->embedding_dim = model_header.embedding_dim;
|
||||
model->mlp_dim = model_header.mlp_dim;
|
||||
model->swiglu_limit = model_header.swiglu_limit;
|
||||
model->head_dim = model_header.head_dim;
|
||||
model->num_heads = model_header.num_heads;
|
||||
model->num_kv_heads = model_header.num_kv_heads;
|
||||
model->attention_window = model_header.attention_window;
|
||||
model->rope_theta = model_header.rope_theta;
|
||||
model->interpolation_scale = model_header.interpolation_scale;
|
||||
model->yarn_offset = model_header.yarn_offset;
|
||||
model->yarn_scale = model_header.yarn_scale;
|
||||
model->yarn_multiplier = model_header.yarn_multiplier;
|
||||
model->rmsnorm_epsilon = model_header.rmsnorm_epsilon;
|
||||
|
||||
model->max_batch_tokens = GPTOSS_DEFAULT_BATCH_SIZE;
|
||||
|
||||
struct gptoss_uuid tokenizer_uuid;
|
||||
status = read_fd(fd, &tokenizer_uuid, sizeof(tokenizer_uuid), path);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
file_offset += sizeof(tokenizer_uuid);
|
||||
|
||||
if (!gptoss_is_tiktoken_tokenizer_uuid(&tokenizer_uuid)) {
|
||||
GPTOSS_LOG_ERROR("unsupported tokenizer UUID " UUID_FORMAT, UUID_ARGS(tokenizer_uuid));
|
||||
status = gptoss_status_invalid_argument;
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
struct gptoss_tiktoken_tokenizer_header tokenizer_header;
|
||||
status = read_fd(fd, &tokenizer_header, sizeof(tokenizer_header), path);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
file_offset += sizeof(tokenizer_header);
|
||||
|
||||
tokenizer = malloc(sizeof(struct gptoss_tokenizer));
|
||||
if (tokenizer == NULL) {
|
||||
GPTOSS_LOG_ERROR("failed to allocate %zu bytes for tokenizer descriptor", sizeof(struct gptoss_tokenizer));
|
||||
status = gptoss_status_insufficient_memory;
|
||||
goto cleanup;
|
||||
}
|
||||
memset(tokenizer, 0, sizeof(struct gptoss_tokenizer));
|
||||
// Initialize all special token IDs to UINT32_MAX (0xFF in all bytes)
|
||||
memset(tokenizer->special_token_id, 0xFF, sizeof(tokenizer->special_token_id));
|
||||
|
||||
atomic_store_explicit(&tokenizer->ref_count, 1, memory_order_relaxed);
|
||||
tokenizer->num_special_tokens = tokenizer_header.num_special_tokens;
|
||||
tokenizer->num_text_tokens = tokenizer_header.num_text_tokens;
|
||||
model->vocabulary_size = tokenizer_header.num_special_tokens + tokenizer_header.num_text_tokens;
|
||||
for (uint32_t t = 0; t < tokenizer_header.num_special_tokens; t++) {
|
||||
struct gptoss_uuid token_uuid;
|
||||
status = read_fd(fd, &token_uuid, sizeof(token_uuid), path);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
file_offset += sizeof(token_uuid);
|
||||
|
||||
const enum gptoss_special_token token = gptoss_special_token_decode_uuid(&token_uuid);
|
||||
if (token != gptoss_special_token_invalid) {
|
||||
tokenizer->special_token_id[token - 1] = tokenizer_header.num_text_tokens + t;
|
||||
}
|
||||
}
|
||||
|
||||
const size_t tokenizer_start_offset = file_offset;
|
||||
const size_t tokenizer_end_offset = tokenizer_start_offset + tokenizer_header.regex_size + tokenizer_header.tokens_size;
|
||||
const size_t tokenizer_mapping_start = round_down_to_page_size(tokenizer_start_offset);
|
||||
const size_t tokenizer_mapping_size = round_up_to_page_size(tokenizer_end_offset) - tokenizer_mapping_start;
|
||||
void* tokenizer_mapping_ptr = mmap(NULL, tokenizer_mapping_size, PROT_READ, MAP_PRIVATE, fd, tokenizer_mapping_start);
|
||||
if (tokenizer_mapping_ptr == (void*) -1) {
|
||||
GPTOSS_LOG_ERROR("failed to mmap(%s) tokenizer at offset %zu size %zu",
|
||||
path, tokenizer_mapping_start, tokenizer_mapping_size);
|
||||
status = gptoss_status_io_error;
|
||||
goto cleanup;
|
||||
}
|
||||
tokenizer->mapping_ptr = tokenizer_mapping_ptr;
|
||||
tokenizer->mapping_size = tokenizer_mapping_size;
|
||||
tokenizer->regex_ptr = (const char*) tokenizer_mapping_ptr + (tokenizer_start_offset - tokenizer_mapping_start);
|
||||
tokenizer->tokens_ptr = tokenizer->regex_ptr + tokenizer_header.regex_size;
|
||||
|
||||
if (madvise(tokenizer_mapping_ptr, tokenizer_mapping_size, MADV_RANDOM | MADV_WILLNEED) != 0) {
|
||||
GPTOSS_LOG_WARNING("madvise(%s, size=%zu) failed with error %d", path, tokenizer_mapping_size, errno);
|
||||
}
|
||||
|
||||
prefetch_fd(fd, tokenizer_mapping_start, tokenizer_mapping_size, path);
|
||||
|
||||
struct stat model_stat = {0};
|
||||
int stat_result = fstat(fd, &model_stat);
|
||||
if (stat_result != 0) {
|
||||
GPTOSS_LOG_ERROR("stat(%s) failed with error %d", path, errno);
|
||||
status = gptoss_status_io_error;
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
const size_t model_mapping_start = round_up_to_page_size(tokenizer_end_offset);
|
||||
const size_t model_mapping_size = round_up_to_page_size((size_t) model_stat.st_size) - model_mapping_start;
|
||||
void* model_mapping_ptr = mmap(NULL, model_mapping_size, PROT_READ, MAP_PRIVATE, fd, model_mapping_start);
|
||||
if (model_mapping_ptr == (void*) -1) {
|
||||
GPTOSS_LOG_ERROR("failed to mmap(%s) model weights at offset %zu size %zu",
|
||||
path, model_mapping_start, model_mapping_size);
|
||||
status = gptoss_status_io_error;
|
||||
goto cleanup;
|
||||
}
|
||||
model->mapping_ptr = model_mapping_ptr;
|
||||
model->mapping_size = model_mapping_size;
|
||||
|
||||
if (madvise(model_mapping_ptr, model_mapping_size, MADV_SEQUENTIAL | MADV_WILLNEED) != 0) {
|
||||
GPTOSS_LOG_WARNING("madvise(%s, size=%zu) failed with error %d", path, model_mapping_size, errno);
|
||||
}
|
||||
|
||||
prefetch_fd(fd, model_mapping_start, model_mapping_size, path);
|
||||
|
||||
// Initialize Metal
|
||||
status = gptoss_metal_device_create_system_default(&model->device);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
model->max_threadgroups = model->device.num_cores * 3;
|
||||
status = gptoss_metal_command_queue_create(&model->device, &model->command_queue);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
// Metal kernels
|
||||
status = gptoss_metal_library_create_default(&model->device, &model->library);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_function_create(&model->library, "gptoss_bf16_f32_embeddings", &model->bf16_f32_embeddings_fn);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_rmsnorm", &model->f32_bf16w_rmsnorm_fn);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_matmul", &model->f32_bf16w_matmul_fn);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_unembedding", &model->f32_bf16w_unembedding_fn);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_function_create(&model->library, "gptoss_f32_rope", &model->f32_rope_fn);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_function_create(&model->library, "gptoss_f32_mf4w_moe_matmul_swiglu", &model->f32_mf4w_moe_matmul_swiglu_fn);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_function_create(&model->library, "gptoss_f32_mf4w_moe_matmul", &model->f32_mf4w_moe_matmul_fn);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_function_create(&model->library, "gptoss_f32_accumulate_e4", &model->f32_accumulate_e4_fn);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_function_create(&model->library, "gptoss_f32_topk_softmax_e32_k4", &model->f32_topk_softmax_e32_k4_fn);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_function_create(&model->library, "gptoss_f32_topk_softmax_e128_k4", &model->f32_topk_softmax_e128_k4_fn);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_function_create(&model->library, "gptoss_f32_softmax", &model->f32_softmax_fn);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_function_create(&model->library, "gptoss_f32_sdpa_q8_d64", &model->f32_sdpa_q8_d64_fn);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
// Weight buffers
|
||||
const char* current_ptr = (const char*) model->mapping_ptr;
|
||||
|
||||
const size_t embedding_weight_size = math_round_up_po2(model->vocabulary_size * model->embedding_dim * sizeof(gptoss_bfloat16), 16);
|
||||
model->attn_rmsnorm_gain_offset = embedding_weight_size;
|
||||
const size_t rmsnorm_weight_size = math_round_up_po2(model->embedding_dim * sizeof(gptoss_bfloat16), 16);
|
||||
model->attn_qkv_weight_offset = model->attn_rmsnorm_gain_offset + rmsnorm_weight_size;
|
||||
const size_t attn_qkv_dim = model->head_dim * (model->num_heads + 2 * model->num_kv_heads);
|
||||
const size_t attn_qkv_weight_size = math_round_up_po2(attn_qkv_dim * model->embedding_dim * sizeof(gptoss_bfloat16), 16);
|
||||
model->attn_qkv_bias_offset = model->attn_qkv_weight_offset + attn_qkv_weight_size;
|
||||
const size_t attn_qkv_bias_size = math_round_up_po2(attn_qkv_dim * sizeof(gptoss_bfloat16), 16);
|
||||
model->attn_sdpa_sink_offset = model->attn_qkv_bias_offset + attn_qkv_bias_size;
|
||||
const size_t attn_sink_weight_size = math_round_up_po2(model->num_heads * sizeof(gptoss_bfloat16), 16);
|
||||
model->attn_out_weight_offset = model->attn_sdpa_sink_offset + attn_sink_weight_size;
|
||||
const size_t attn_out_weight_size = math_round_up_po2(model->embedding_dim * model->num_heads * model->head_dim * sizeof(gptoss_bfloat16), 16);
|
||||
model->attn_out_bias_offset = model->attn_out_weight_offset + attn_out_weight_size;
|
||||
const size_t attn_out_bias_size = math_round_up_po2(model->embedding_dim * sizeof(gptoss_bfloat16), 16);
|
||||
model->mlp_rmsnorm_gain_offset = model->attn_out_bias_offset + attn_out_bias_size;
|
||||
model->mlp_gate_weight_offset = model->mlp_rmsnorm_gain_offset + rmsnorm_weight_size;
|
||||
const size_t mlp_gate_weight_size = math_round_up_po2(model->num_experts * model->embedding_dim * sizeof(gptoss_bfloat16), 16);
|
||||
model->mlp_gate_bias_offset = model->mlp_gate_weight_offset + mlp_gate_weight_size;
|
||||
const size_t mlp_gate_bias_size = math_round_up_po2(model->num_experts * sizeof(gptoss_bfloat16), 16);
|
||||
const size_t per_block_shared_weights_size =
|
||||
rmsnorm_weight_size + attn_qkv_weight_size + attn_qkv_bias_size + attn_sink_weight_size + attn_out_weight_size + attn_out_bias_size +
|
||||
rmsnorm_weight_size + mlp_gate_weight_size + mlp_gate_bias_size;
|
||||
model->rmsnorm_weight_offset = embedding_weight_size + model->num_blocks * per_block_shared_weights_size;
|
||||
model->unembedding_weight_offset = model->rmsnorm_weight_offset + rmsnorm_weight_size;
|
||||
const size_t unembedding_weight_size = math_round_up_po2(model->vocabulary_size * model->embedding_dim * sizeof(gptoss_bfloat16), 16);
|
||||
|
||||
model->per_block_shared_weights_size = per_block_shared_weights_size;
|
||||
const size_t shared_weights_size =
|
||||
round_up_to_page_size(embedding_weight_size + rmsnorm_weight_size + unembedding_weight_size + model->num_blocks * per_block_shared_weights_size);
|
||||
|
||||
status = gptoss_metal_buffer_wrap(&model->device, shared_weights_size, current_ptr, &model->shared_weight_buffer);
|
||||
if (status != gptoss_status_success) {
|
||||
GPTOSS_LOG_ERROR("failed to map expert-shared weight of size %zu onto a Metal buffer", shared_weights_size);
|
||||
goto cleanup;
|
||||
}
|
||||
current_ptr += shared_weights_size;
|
||||
model->weights_size += shared_weights_size;
|
||||
|
||||
const size_t mlp_swiglu_weight_block_size = math_round_up_po2(2 * model->mlp_dim * model->embedding_dim / 2, 16);
|
||||
model->mlp_swiglu_scale_offset = mlp_swiglu_weight_block_size;
|
||||
const size_t mlp_swiglu_weight_scale_size = math_round_up_po2(2 * model->mlp_dim * model->embedding_dim / 32, 16);
|
||||
model->mlp_swiglu_bias_offset = model->mlp_swiglu_scale_offset + mlp_swiglu_weight_scale_size;
|
||||
const size_t mlp_swiglu_bias_size = math_round_up_po2(2 * model->mlp_dim * sizeof(gptoss_bfloat16), 16);
|
||||
model->mlp_out_block_offset = model->mlp_swiglu_bias_offset + mlp_swiglu_bias_size;
|
||||
const size_t mlp_out_weight_block_size = math_round_up_po2(model->embedding_dim * model->mlp_dim / 2, 16);
|
||||
model->mlp_out_scale_offset = model->mlp_out_block_offset + mlp_out_weight_block_size;
|
||||
const size_t mlp_out_weight_scale_size = math_round_up_po2(model->embedding_dim * model->mlp_dim / 32, 16);
|
||||
model->mlp_out_bias_offset = model->mlp_out_scale_offset + mlp_out_weight_scale_size;
|
||||
const size_t mlp_out_bias_size = math_round_up_po2(model->embedding_dim * sizeof(gptoss_bfloat16), 16);
|
||||
model->per_expert_block_weight_size =
|
||||
mlp_swiglu_weight_block_size + mlp_swiglu_weight_scale_size + mlp_swiglu_bias_size + mlp_out_weight_block_size + mlp_out_weight_scale_size + mlp_out_bias_size;
|
||||
const size_t moe_block_weight_size = round_up_to_page_size(model->num_experts * model->per_expert_block_weight_size);
|
||||
for (uint32_t n = 0; n < model->num_blocks; n++) {
|
||||
status = gptoss_metal_buffer_wrap(&model->device, moe_block_weight_size, current_ptr, &model->block_weight_buffers[n]);
|
||||
if (status != gptoss_status_success) {
|
||||
GPTOSS_LOG_ERROR("failed to map block #%" PRIu32 " MoE weight of size %zu onto a Metal buffer",
|
||||
n, moe_block_weight_size);
|
||||
goto cleanup;
|
||||
}
|
||||
current_ptr += moe_block_weight_size;
|
||||
model->weights_size += moe_block_weight_size;
|
||||
}
|
||||
|
||||
// Activation buffers
|
||||
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &model->residual_activation_buffer);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &model->rmsnorm_activation_buffer);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->head_dim * (model->num_heads + 2 * model->num_kv_heads) * sizeof(float), NULL, &model->qkv_activation_buffer);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->head_dim * model->num_heads * sizeof(float), NULL, &model->sdpa_activation_buffer);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_experts * sizeof(float), NULL, &model->gate_activation_buffer);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_experts * sizeof(struct gptoss_expert_prediction), NULL, &model->expert_activation_buffer);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_active_experts * model->mlp_dim * sizeof(float), NULL, &model->swiglu_activation_buffer);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_active_experts * model->embedding_dim * sizeof(float), NULL, &model->moe_activation_buffer);
|
||||
if (status != gptoss_status_success) {
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
model->allocation_size =
|
||||
model->residual_activation_buffer.size + model->rmsnorm_activation_buffer.size +
|
||||
model->qkv_activation_buffer.size + model->sdpa_activation_buffer.size +
|
||||
model->gate_activation_buffer.size + model->expert_activation_buffer.size + model->swiglu_activation_buffer.size + model->moe_activation_buffer.size;
|
||||
|
||||
// Commit tokenizer
|
||||
model->tokenizer = tokenizer;
|
||||
tokenizer = NULL;
|
||||
|
||||
// Commit model
|
||||
*model_out = model;
|
||||
model = NULL;
|
||||
|
||||
cleanup:
|
||||
if (fd != -1) {
|
||||
close(fd);
|
||||
fd = -1;
|
||||
}
|
||||
gptoss_model_release(model); // does nothing if model is NULL
|
||||
gptoss_tokenizer_release(tokenizer); // does nothing if tokenizer is NULL
|
||||
return status;
|
||||
}
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_model_get_tokenizer(
|
||||
gptoss_model_t model,
|
||||
gptoss_tokenizer_t* tokenizer_out)
|
||||
{
|
||||
gptoss_tokenizer_t tokenizer = model->tokenizer;
|
||||
atomic_fetch_add_explicit(&tokenizer->ref_count, 1, memory_order_relaxed);
|
||||
*tokenizer_out = tokenizer;
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_model_get_max_context_length(
|
||||
gptoss_model_t model,
|
||||
size_t* max_context_length_out)
|
||||
{
|
||||
*max_context_length_out = model->context_length;
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_model_retain(
|
||||
gptoss_model_t model)
|
||||
{
|
||||
atomic_fetch_add_explicit(&model->ref_count, 1, memory_order_relaxed);
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_model_release(
|
||||
gptoss_model_t model)
|
||||
{
|
||||
if (model != NULL) {
|
||||
if (atomic_fetch_sub_explicit(&model->ref_count, 1, memory_order_acq_rel) == 1) {
|
||||
gptoss_tokenizer_release(model->tokenizer);
|
||||
|
||||
// Activation buffers
|
||||
gptoss_metal_buffer_release(&model->residual_activation_buffer);
|
||||
gptoss_metal_buffer_release(&model->rmsnorm_activation_buffer);
|
||||
gptoss_metal_buffer_release(&model->qkv_activation_buffer);
|
||||
gptoss_metal_buffer_release(&model->sdpa_activation_buffer);
|
||||
gptoss_metal_buffer_release(&model->gate_activation_buffer);
|
||||
gptoss_metal_buffer_release(&model->expert_activation_buffer);
|
||||
gptoss_metal_buffer_release(&model->swiglu_activation_buffer);
|
||||
gptoss_metal_buffer_release(&model->moe_activation_buffer);
|
||||
|
||||
// Weight buffers
|
||||
gptoss_metal_buffer_release(&model->shared_weight_buffer);
|
||||
for (uint32_t n = 0; n < model->num_blocks; n++) {
|
||||
gptoss_metal_buffer_release(&model->block_weight_buffers[n]);
|
||||
}
|
||||
|
||||
// Metal kernels
|
||||
gptoss_metal_function_release(&model->bf16_f32_embeddings_fn);
|
||||
gptoss_metal_function_release(&model->f32_bf16w_rmsnorm_fn);
|
||||
gptoss_metal_function_release(&model->f32_bf16w_matmul_fn);
|
||||
gptoss_metal_function_release(&model->f32_bf16w_unembedding_fn);
|
||||
gptoss_metal_function_release(&model->f32_rope_fn);
|
||||
gptoss_metal_function_release(&model->f32_mf4w_moe_matmul_swiglu_fn);
|
||||
gptoss_metal_function_release(&model->f32_mf4w_moe_matmul_fn);
|
||||
gptoss_metal_function_release(&model->f32_accumulate_e4_fn);
|
||||
gptoss_metal_function_release(&model->f32_topk_softmax_e32_k4_fn);
|
||||
gptoss_metal_function_release(&model->f32_topk_softmax_e128_k4_fn);
|
||||
gptoss_metal_function_release(&model->f32_softmax_fn);
|
||||
gptoss_metal_function_release(&model->f32_sdpa_q8_d64_fn);
|
||||
gptoss_metal_library_release(&model->library);
|
||||
|
||||
gptoss_metal_command_queue_release(&model->command_queue);
|
||||
gptoss_metal_device_release(&model->device);
|
||||
// Weight buffers
|
||||
|
||||
if (model->mapping_ptr != NULL && model->mapping_size != 0) {
|
||||
if (munmap(model->mapping_ptr, model->mapping_size) != 0) {
|
||||
GPTOSS_LOG_WARNING("munmap for model weight mapping failed with error %d", errno);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t model_size = sizeof(struct gptoss_model) + model->num_blocks * sizeof(struct gptoss_metal_buffer);
|
||||
memset(model, 0, model_size);
|
||||
free(model);
|
||||
}
|
||||
}
|
||||
return gptoss_status_success;
|
||||
}
|
||||
218
gpt_oss/metal/source/moematmul.metal
Normal file
218
gpt_oss/metal/source/moematmul.metal
Normal file
@@ -0,0 +1,218 @@
|
||||
#include <metal_common>
|
||||
#include <metal_compute>
|
||||
#include <metal_math>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include <internal/kernel-args.h>
|
||||
|
||||
#pragma METAL fp math_mode(safe)
|
||||
#pragma METAL fp contract(off)
|
||||
|
||||
|
||||
// Each simdgroup reduces all channels of the input and computes a single channel of the output
|
||||
// + Efficient synchronization
|
||||
// + Sequential memory access within a warp
|
||||
// Each threadgroup computes (simdgroups_per_threadgroup) consecutive output channels
|
||||
// + Reuse input vector from threadgroup memory
|
||||
// + Avoid synchronization across warps when doing reduction
|
||||
|
||||
kernel void gptoss_f32_mf4w_moe_matmul_swiglu(
|
||||
constant gptoss_moe_matmul_swiglu_args& args [[ buffer(0) ]],
|
||||
const device float4* input [[ buffer(1) ]],
|
||||
const device gptoss_expert_prediction* expert [[ buffer(2) ]],
|
||||
const device uint4* weight_blocks [[ buffer(3) ]],
|
||||
const device uchar* weight_scales [[ buffer(4) ]],
|
||||
const device bfloat* bias [[ buffer(5) ]],
|
||||
device float* output [[ buffer(6) ]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_index_in_threadgroup]],
|
||||
uint simdgroup_tid [[thread_index_in_simdgroup]],
|
||||
uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
|
||||
uint num_simdgroups [[simdgroups_per_threadgroup]])
|
||||
{
|
||||
const uint simdgroup_size = 32;
|
||||
threadgroup float threadgroup_buffer[32];
|
||||
|
||||
const uint num_column_vecs = args.num_column_vecs;
|
||||
const uint row = gid.x * num_simdgroups + simdgroup_idx;
|
||||
const uint expert_id = expert[gid.y * args.num_active_experts + gid.z].expert_id;
|
||||
|
||||
input += 8 * (gid.y * num_column_vecs + simdgroup_tid);
|
||||
weight_blocks = (const device uint4*) ((uintptr_t) (weight_blocks + num_column_vecs * row + simdgroup_tid) + expert_id * args.weight_expert_stride);
|
||||
weight_scales = (const device uchar*) ((uintptr_t) (weight_scales + num_column_vecs * row + simdgroup_tid) + expert_id * args.weight_expert_stride);
|
||||
bias = (const device bfloat*) ((uintptr_t) (bias + row) + expert_id * args.weight_expert_stride);
|
||||
output += gid.y * args.num_rows + gid.x * (num_simdgroups / 2) + gid.z * args.output_expert_stride;
|
||||
|
||||
uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;
|
||||
|
||||
float4 sum4 = 0.0f;
|
||||
do {
|
||||
const uint4 wblock = *weight_blocks;
|
||||
const float wscale = as_type<float>(static_cast<uint>(*weight_scales) << 23);
|
||||
uint4 wblock02468ACEGIKMOQSU = wblock + wblock;
|
||||
uint4 wblock13579BDFHJLNPRTV = wblock >> 3;
|
||||
wblock02468ACEGIKMOQSU &= 0x1E1E1E1Eu;
|
||||
wblock13579BDFHJLNPRTV &= 0x1E1E1E1Eu;
|
||||
wblock02468ACEGIKMOQSU += 0x70707070u;
|
||||
wblock13579BDFHJLNPRTV += 0x70707070u;
|
||||
wblock02468ACEGIKMOQSU &= 0x8E8E8E8Eu;
|
||||
wblock13579BDFHJLNPRTV &= 0x8E8E8E8Eu;
|
||||
const uint4 wblock26AEIMQU = wblock02468ACEGIKMOQSU & 0xFF00FF00u;
|
||||
const uint4 wblock048CGKOS = (wblock02468ACEGIKMOQSU << 8) & 0xFF00FF00u;
|
||||
const uint4 wblock37BFJNRV = wblock13579BDFHJLNPRTV & 0xFF00FF00u;
|
||||
const uint4 wblock159DHLPT = (wblock13579BDFHJLNPRTV << 8) & 0xFF00FF00u;
|
||||
const float4 w048C = static_cast<float4>(as_type<half4>(wblock048CGKOS.xy));
|
||||
const float4 wGKOS = static_cast<float4>(as_type<half4>(wblock048CGKOS.zw));
|
||||
const float4 w26AE = static_cast<float4>(as_type<half4>(wblock26AEIMQU.xy));
|
||||
const float4 wIMQU = static_cast<float4>(as_type<half4>(wblock26AEIMQU.zw));
|
||||
const float4 w159D = static_cast<float4>(as_type<half4>(wblock159DHLPT.xy));
|
||||
const float4 wHLPT = static_cast<float4>(as_type<half4>(wblock159DHLPT.zw));
|
||||
const float4 w37BF = static_cast<float4>(as_type<half4>(wblock37BFJNRV.xy));
|
||||
const float4 wJNRV = static_cast<float4>(as_type<half4>(wblock37BFJNRV.zw));
|
||||
|
||||
const float4 w0123 = (float4) { w048C.x, w159D.x, w26AE.x, w37BF.x };
|
||||
const float4 w4567 = (float4) { w048C.y, w159D.y, w26AE.y, w37BF.y };
|
||||
const float4 w89AB = (float4) { w048C.z, w159D.z, w26AE.z, w37BF.z };
|
||||
const float4 wCDEF = (float4) { w048C.w, w159D.w, w26AE.w, w37BF.w };
|
||||
const float4 wGHIJ = (float4) { wGKOS.x, wHLPT.x, wIMQU.x, wJNRV.x };
|
||||
const float4 wKLMN = (float4) { wGKOS.y, wHLPT.y, wIMQU.y, wJNRV.y };
|
||||
const float4 wOPQR = (float4) { wGKOS.z, wHLPT.z, wIMQU.z, wJNRV.z };
|
||||
const float4 wSTUV = (float4) { wGKOS.w, wHLPT.w, wIMQU.w, wJNRV.w };
|
||||
|
||||
const float4 i0123 = input[0];
|
||||
const float4 i4567 = input[1];
|
||||
const float4 i89AB = input[2];
|
||||
const float4 iCDEF = input[3];
|
||||
const float4 iGHIJ = input[4];
|
||||
const float4 iKLMN = input[5];
|
||||
const float4 iOPQR = input[6];
|
||||
const float4 iSTUV = input[7];
|
||||
|
||||
float4 psum0 = i0123 * w0123;
|
||||
float4 psum1 = i4567 * w4567;
|
||||
psum0 = metal::fma(i89AB, w89AB, psum0);
|
||||
psum1 = metal::fma(iCDEF, wCDEF, psum1);
|
||||
psum0 = metal::fma(iGHIJ, wGHIJ, psum0);
|
||||
psum1 = metal::fma(iKLMN, wKLMN, psum1);
|
||||
psum0 = metal::fma(iOPQR, wOPQR, psum0);
|
||||
psum1 = metal::fma(iSTUV, wSTUV, psum1);
|
||||
sum4 = metal::fma(psum0, wscale, sum4);
|
||||
sum4 = metal::fma(psum1, wscale, sum4);
|
||||
|
||||
weight_blocks += simdgroup_size;
|
||||
weight_scales += simdgroup_size;
|
||||
input += 8 * simdgroup_size;
|
||||
} while (--num_iter != 0);
|
||||
const float2 sum2 = sum4.xy + sum4.zw;
|
||||
float sum = sum2.x + sum2.y;
|
||||
sum = metal::simd_sum(sum);
|
||||
if (metal::simd_is_first()) {
|
||||
sum += static_cast<float>(*bias);
|
||||
threadgroup_buffer[simdgroup_idx] = sum;
|
||||
}
|
||||
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
|
||||
if (tid * 2 < num_simdgroups) {
|
||||
const float2 x = reinterpret_cast<const threadgroup float2*>(threadgroup_buffer)[tid];
|
||||
const float swish_x = metal::min(x.x, args.swiglu_max);
|
||||
const float linear_x = metal::clamp(x.y, args.swiglu_min, args.swiglu_max);
|
||||
const float alpha = 1.702f;
|
||||
const float swish_y = swish_x / (1.0f + metal::precise::exp(-alpha * swish_x));
|
||||
const float swiglu_y = metal::fma(swish_y, linear_x, swish_y);
|
||||
output[tid] = swiglu_y;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void gptoss_f32_mf4w_moe_matmul(
|
||||
constant gptoss_moe_matmul_args& args [[ buffer(0) ]],
|
||||
const device float4* input [[ buffer(1) ]],
|
||||
const device gptoss_expert_prediction* expert [[ buffer(2) ]],
|
||||
const device uint4* weight_blocks [[ buffer(3) ]],
|
||||
const device uchar* weight_scales [[ buffer(4) ]],
|
||||
const device bfloat* bias [[ buffer(5) ]],
|
||||
device float* output [[ buffer(6) ]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_index_in_threadgroup]],
|
||||
uint simdgroup_tid [[thread_index_in_simdgroup]],
|
||||
uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
|
||||
uint num_simdgroups [[simdgroups_per_threadgroup]])
|
||||
{
|
||||
const uint simdgroup_size = 32;
|
||||
|
||||
const uint num_column_vecs = args.num_column_vecs;
|
||||
const uint row = gid.x * num_simdgroups + simdgroup_idx;
|
||||
const uint expert_id = expert[gid.y * args.num_active_experts + gid.z].expert_id;
|
||||
|
||||
input += 8 * (gid.y * num_column_vecs + simdgroup_tid + gid.z * args.input_expert_stride);
|
||||
weight_blocks = (const device uint4*) ((uintptr_t) (weight_blocks + num_column_vecs * row + simdgroup_tid) + expert_id * args.weight_expert_stride);
|
||||
weight_scales = (const device uchar*) ((uintptr_t) (weight_scales + num_column_vecs * row + simdgroup_tid) + expert_id * args.weight_expert_stride);
|
||||
bias = (const device bfloat*) ((uintptr_t) (bias + row) + expert_id * args.weight_expert_stride);
|
||||
output += gid.y * args.num_rows + row + gid.z * args.output_expert_stride;
|
||||
|
||||
uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;
|
||||
|
||||
float4 sum4 = 0.0f;
|
||||
do {
|
||||
const uint4 wblock = *weight_blocks;
|
||||
const float wscale = as_type<float>(static_cast<uint>(*weight_scales) << 23);
|
||||
uint4 wblock02468ACEGIKMOQSU = wblock + wblock;
|
||||
uint4 wblock13579BDFHJLNPRTV = wblock >> 3;
|
||||
wblock02468ACEGIKMOQSU &= 0x1E1E1E1Eu;
|
||||
wblock13579BDFHJLNPRTV &= 0x1E1E1E1Eu;
|
||||
wblock02468ACEGIKMOQSU += 0x70707070u;
|
||||
wblock13579BDFHJLNPRTV += 0x70707070u;
|
||||
wblock02468ACEGIKMOQSU &= 0x8E8E8E8Eu;
|
||||
wblock13579BDFHJLNPRTV &= 0x8E8E8E8Eu;
|
||||
const uint4 wblock26AEIMQU = wblock02468ACEGIKMOQSU & 0xFF00FF00u;
|
||||
const uint4 wblock048CGKOS = (wblock02468ACEGIKMOQSU << 8) & 0xFF00FF00u;
|
||||
const uint4 wblock37BFJNRV = wblock13579BDFHJLNPRTV & 0xFF00FF00u;
|
||||
const uint4 wblock159DHLPT = (wblock13579BDFHJLNPRTV << 8) & 0xFF00FF00u;
|
||||
const float4 w048C = static_cast<float4>(as_type<half4>(wblock048CGKOS.xy));
|
||||
const float4 wGKOS = static_cast<float4>(as_type<half4>(wblock048CGKOS.zw));
|
||||
const float4 w26AE = static_cast<float4>(as_type<half4>(wblock26AEIMQU.xy));
|
||||
const float4 wIMQU = static_cast<float4>(as_type<half4>(wblock26AEIMQU.zw));
|
||||
const float4 w159D = static_cast<float4>(as_type<half4>(wblock159DHLPT.xy));
|
||||
const float4 wHLPT = static_cast<float4>(as_type<half4>(wblock159DHLPT.zw));
|
||||
const float4 w37BF = static_cast<float4>(as_type<half4>(wblock37BFJNRV.xy));
|
||||
const float4 wJNRV = static_cast<float4>(as_type<half4>(wblock37BFJNRV.zw));
|
||||
|
||||
const float4 w0123 = (float4) { w048C.x, w159D.x, w26AE.x, w37BF.x };
|
||||
const float4 w4567 = (float4) { w048C.y, w159D.y, w26AE.y, w37BF.y };
|
||||
const float4 w89AB = (float4) { w048C.z, w159D.z, w26AE.z, w37BF.z };
|
||||
const float4 wCDEF = (float4) { w048C.w, w159D.w, w26AE.w, w37BF.w };
|
||||
const float4 wGHIJ = (float4) { wGKOS.x, wHLPT.x, wIMQU.x, wJNRV.x };
|
||||
const float4 wKLMN = (float4) { wGKOS.y, wHLPT.y, wIMQU.y, wJNRV.y };
|
||||
const float4 wOPQR = (float4) { wGKOS.z, wHLPT.z, wIMQU.z, wJNRV.z };
|
||||
const float4 wSTUV = (float4) { wGKOS.w, wHLPT.w, wIMQU.w, wJNRV.w };
|
||||
|
||||
const float4 i0123 = input[0];
|
||||
const float4 i4567 = input[1];
|
||||
const float4 i89AB = input[2];
|
||||
const float4 iCDEF = input[3];
|
||||
const float4 iGHIJ = input[4];
|
||||
const float4 iKLMN = input[5];
|
||||
const float4 iOPQR = input[6];
|
||||
const float4 iSTUV = input[7];
|
||||
|
||||
float4 psum0 = i0123 * w0123;
|
||||
float4 psum1 = i4567 * w4567;
|
||||
psum0 = metal::fma(i89AB, w89AB, psum0);
|
||||
psum1 = metal::fma(iCDEF, wCDEF, psum1);
|
||||
psum0 = metal::fma(iGHIJ, wGHIJ, psum0);
|
||||
psum1 = metal::fma(iKLMN, wKLMN, psum1);
|
||||
psum0 = metal::fma(iOPQR, wOPQR, psum0);
|
||||
psum1 = metal::fma(iSTUV, wSTUV, psum1);
|
||||
sum4 = metal::fma(psum0, wscale, sum4);
|
||||
sum4 = metal::fma(psum1, wscale, sum4);
|
||||
|
||||
weight_blocks += simdgroup_size;
|
||||
weight_scales += simdgroup_size;
|
||||
input += 8 * simdgroup_size;
|
||||
} while (--num_iter != 0);
|
||||
const float2 sum2 = sum4.xy + sum4.zw;
|
||||
float sum = sum2.x + sum2.y;
|
||||
sum = metal::simd_sum(sum);
|
||||
if (metal::simd_is_first()) {
|
||||
sum += static_cast<float>(*bias);
|
||||
*output = sum;
|
||||
}
|
||||
}
|
||||
97
gpt_oss/metal/source/random.metal
Normal file
97
gpt_oss/metal/source/random.metal
Normal file
@@ -0,0 +1,97 @@
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
#include <internal/kernel-args.h>
|
||||
|
||||
#pragma METAL fp math_mode(safe)
|
||||
#pragma METAL fp contract(off)
|
||||
|
||||
|
||||
inline static uint rng_squares32(ulong offset, ulong seed) {
|
||||
const ulong y = offset * seed;
|
||||
const ulong z = y + seed;
|
||||
|
||||
/* Round 1 */
|
||||
ulong x = y * y + y;
|
||||
x = metal::rotate(x, 32ul);
|
||||
|
||||
/* Round 2 */
|
||||
x = x * x + z;
|
||||
x = metal::rotate(x, 32ul);
|
||||
|
||||
/* Round 3 */
|
||||
x = x * x + y;
|
||||
x = metal::rotate(x, 32ul);
|
||||
|
||||
/* Round 4 */
|
||||
x = x * x + z;
|
||||
return as_type<uint2>(x).y;
|
||||
}
|
||||
|
||||
kernel void gptoss_u32_fill_random(
|
||||
constant gptoss_u32_fill_random_args& args [[ buffer(0) ]],
|
||||
device uint* output [[ buffer(1) ]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_position_in_threadgroup]],
|
||||
uint threadgroup_size [[ threads_per_threadgroup ]])
|
||||
{
|
||||
const ulong num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;
|
||||
const ulong threadgroup_start = gid * num_vecs_per_threadgroup;
|
||||
const ulong threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, args.num_vecs);
|
||||
const ulong thread_start = threadgroup_start + tid;
|
||||
uint num_iter = static_cast<uint>((threadgroup_end - thread_start + (threadgroup_size - 1)) / threadgroup_size);
|
||||
|
||||
output += thread_start;
|
||||
ulong offset = args.offset + thread_start;
|
||||
for (; num_iter != 0; num_iter--) {
|
||||
*output = rng_squares32(offset, args.seed);
|
||||
output += threadgroup_size;
|
||||
offset += threadgroup_size;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void gptoss_f32_fill_random(
|
||||
constant gptoss_f32_fill_random_args& args [[ buffer(0) ]],
|
||||
device float* output [[ buffer(1) ]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_position_in_threadgroup]],
|
||||
uint threadgroup_size [[ threads_per_threadgroup ]])
|
||||
{
|
||||
const ulong num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;
|
||||
const ulong threadgroup_start = gid * num_vecs_per_threadgroup;
|
||||
const ulong threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, args.num_vecs);
|
||||
const ulong thread_start = threadgroup_start + tid;
|
||||
uint num_iter = static_cast<uint>((threadgroup_end - thread_start + (threadgroup_size - 1)) / threadgroup_size);
|
||||
|
||||
output += thread_start;
|
||||
ulong offset = args.offset + thread_start;
|
||||
for (; num_iter != 0; num_iter--) {
|
||||
const uint word = rng_squares32(offset, args.seed);
|
||||
*output = metal::fma(static_cast<float>(as_type<int>(word)), args.scale, args.bias);
|
||||
output += threadgroup_size;
|
||||
offset += threadgroup_size;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void gptoss_bf16_fill_random(
|
||||
constant gptoss_f32_fill_random_args& args [[ buffer(0) ]],
|
||||
device bfloat* output [[ buffer(1) ]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_position_in_threadgroup]],
|
||||
uint threadgroup_size [[ threads_per_threadgroup ]])
|
||||
{
|
||||
const ulong num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;
|
||||
const ulong threadgroup_start = gid * num_vecs_per_threadgroup;
|
||||
const ulong threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, args.num_vecs);
|
||||
const ulong thread_start = threadgroup_start + tid;
|
||||
uint num_iter = static_cast<uint>((threadgroup_end - thread_start + (threadgroup_size - 1)) / threadgroup_size);
|
||||
|
||||
output += thread_start;
|
||||
ulong offset = args.offset + thread_start;
|
||||
for (; num_iter != 0; num_iter--) {
|
||||
const uint word = rng_squares32(offset, args.seed);
|
||||
*output = static_cast<bfloat>(metal::fma(static_cast<float>(as_type<int>(word)), args.scale, args.bias));
|
||||
output += threadgroup_size;
|
||||
offset += threadgroup_size;
|
||||
}
|
||||
}
|
||||
54
gpt_oss/metal/source/rmsnorm.metal
Normal file
54
gpt_oss/metal/source/rmsnorm.metal
Normal file
@@ -0,0 +1,54 @@
|
||||
#include <metal_compute>
|
||||
#include <metal_math>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include <internal/kernel-args.h>
|
||||
|
||||
#pragma METAL fp math_mode(safe)
|
||||
#pragma METAL fp contract(off)
|
||||
|
||||
|
||||
[[max_total_threads_per_threadgroup(1024)]]
|
||||
kernel void gptoss_f32_bf16w_rmsnorm(
|
||||
constant gptoss_rmsnorm_args& args [[ buffer(0) ]],
|
||||
const device float4* input [[ buffer(1) ]],
|
||||
const device bfloat4* weights [[ buffer(2) ]],
|
||||
device float4* output [[ buffer(3) ]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_position_in_threadgroup]],
|
||||
uint threadgroup_size [[ threads_per_threadgroup ]])
|
||||
{
|
||||
const uint simdgroup_size = 32;
|
||||
threadgroup float threadgroup_buffer[32];
|
||||
|
||||
input += gid * args.num_vecs;
|
||||
output += gid * args.num_vecs;
|
||||
|
||||
float4 sumsq4 = 0.0f;
|
||||
for (uint i = tid; i < args.num_vecs; i += threadgroup_size) {
|
||||
const float4 val = input[i];
|
||||
sumsq4 = metal::fma(val, val, sumsq4);
|
||||
}
|
||||
|
||||
// Tree-reduce sumsq within thread, then all-reduce within threadgroup.
|
||||
const float2 sumsq2 = sumsq4.xy + sumsq4.zw;
|
||||
float sumsq = sumsq2.x + sumsq2.y;
|
||||
// Warning: this all-reduce works only for simdgroup of 32 threads and threadgroup of 32*32=1024 threads.
|
||||
sumsq = metal::simd_sum(sumsq);
|
||||
if (metal::simd_is_first()) {
|
||||
const uint simdgroup_idx = tid / simdgroup_size;
|
||||
threadgroup_buffer[simdgroup_idx] = sumsq;
|
||||
}
|
||||
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
|
||||
const uint simdgroup_tid = tid % simdgroup_size;
|
||||
sumsq = threadgroup_buffer[simdgroup_tid];
|
||||
sumsq = metal::simd_sum(sumsq);
|
||||
|
||||
const float avgsq = sumsq / args.num_channels;
|
||||
const float scale = metal::precise::rsqrt(avgsq + args.epsilon);
|
||||
for (uint i = tid; i < args.num_vecs; i += threadgroup_size) {
|
||||
const float4 val = input[i] * scale;
|
||||
const float4 weight_val = static_cast<float4>(weights[i]);
|
||||
output[i] = val * weight_val;
|
||||
}
|
||||
}
|
||||
38
gpt_oss/metal/source/rope.metal
Normal file
38
gpt_oss/metal/source/rope.metal
Normal file
@@ -0,0 +1,38 @@
|
||||
#include <metal_common>
|
||||
#include <metal_math>
|
||||
|
||||
#include <internal/kernel-args.h>
|
||||
|
||||
#pragma METAL fp math_mode(safe)
|
||||
#pragma METAL fp contract(off)
|
||||
|
||||
|
||||
// Each thread handles 2 head elements.
|
||||
// Each simdgroup handles one head (64 head elements).
|
||||
|
||||
kernel void gptoss_f32_rope(
|
||||
constant gptoss_rope_args& args [[ buffer(0) ]],
|
||||
device float2* activations [[ buffer(1) ]],
|
||||
uint2 gid [[thread_position_in_grid]])
|
||||
{
|
||||
const uint num_head_dims = 64;
|
||||
const float head_idx = static_cast<float>(gid.x % (num_head_dims / 2));
|
||||
const uint token_idx = args.token_offset + gid.y;
|
||||
activations += gid.y * args.token_stride + gid.x;
|
||||
|
||||
const float2 input_vals = *activations;
|
||||
const float inv_extrapolation_freq = metal::precise::exp(head_idx * args.freq_scale);
|
||||
const float inv_interpolation_freq = inv_extrapolation_freq * args.interpolation_scale;
|
||||
const float alpha = metal::saturate(metal::fma(head_idx, args.yarn_scale, args.yarn_offset));
|
||||
const float inv_freq = metal::mix(inv_extrapolation_freq, inv_interpolation_freq, alpha);
|
||||
|
||||
const float phi = static_cast<float>(token_idx) * inv_freq;
|
||||
const float yarn_multiplier = args.yarn_multiplier;
|
||||
float cosphi;
|
||||
const float sinphi = metal::precise::sincos(phi, cosphi) * yarn_multiplier;
|
||||
cosphi *= yarn_multiplier;
|
||||
|
||||
const float output_re = metal::fma(-input_vals.y, sinphi, input_vals.x * cosphi);
|
||||
const float output_im = metal::fma(input_vals.y, cosphi, input_vals.x * sinphi);
|
||||
*activations = (float2) { output_re, output_im };
|
||||
}
|
||||
60
gpt_oss/metal/source/sample.metal
Normal file
60
gpt_oss/metal/source/sample.metal
Normal file
@@ -0,0 +1,60 @@
|
||||
#include <metal_compute>
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include <internal/kernel-args.h>
|
||||
|
||||
#pragma METAL fp math_mode(safe)
|
||||
#pragma METAL fp contract(off)
|
||||
|
||||
|
||||
kernel void gptoss_f32_softmax(
|
||||
constant gptoss_softmax_args& args [[ buffer(0) ]],
|
||||
const device float* score [[ buffer(1) ]],
|
||||
const device uint2* argmax [[ buffer(2) ]],
|
||||
device float* prob [[ buffer(3) ]],
|
||||
device float* sum [[ buffer(4) ]],
|
||||
uint tidx [[thread_index_in_threadgroup]],
|
||||
uint2 gid [[threadgroup_position_in_grid]],
|
||||
uint2 threadgroup_size [[threads_per_threadgroup]],
|
||||
uint simdgroup_tid [[thread_index_in_simdgroup]],
|
||||
uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
|
||||
uint num_simdgroups [[simdgroups_per_threadgroup]])
|
||||
{
|
||||
threadgroup float threadgroup_sumexp[32];
|
||||
|
||||
score += gid.y * args.num_vecs + gid.x * args.num_vecs_per_threadgroup;
|
||||
prob += gid.y * args.num_vecs + gid.x * args.num_vecs_per_threadgroup;
|
||||
sum += gid.y * args.max_threadgroups;
|
||||
|
||||
uint max_bits = argmax[gid.y].y;
|
||||
if (static_cast<int>(max_bits) >= 0) {
|
||||
max_bits ^= 0x7FFFFFFFu;
|
||||
}
|
||||
const float max_val = as_type<float>(max_bits);
|
||||
float sum_exp = 0.0f;
|
||||
const uint num_vecs_per_threadgroup = metal::min(args.num_vecs - gid.x * args.num_vecs_per_threadgroup, args.num_vecs_per_threadgroup);
|
||||
for (uint i = tidx; i < num_vecs_per_threadgroup; i += threadgroup_size.x) {
|
||||
const float score_val = score[i];
|
||||
const float prob_val = metal::precise::exp((score_val - max_val) * args.temperature);
|
||||
prob[i] = prob_val;
|
||||
sum_exp += prob_val;
|
||||
}
|
||||
sum_exp = metal::simd_sum(sum_exp);
|
||||
if (metal::simd_is_first()) {
|
||||
threadgroup_sumexp[simdgroup_idx] = sum_exp;
|
||||
}
|
||||
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
|
||||
if (simdgroup_idx == 0) {
|
||||
// Sum-Reduce threadgroup_sumexp
|
||||
sum_exp = 0.0f;
|
||||
if (simdgroup_tid < num_simdgroups) {
|
||||
sum_exp = threadgroup_sumexp[simdgroup_tid];
|
||||
}
|
||||
sum_exp = metal::simd_sum(sum_exp);
|
||||
if (metal::simd_is_first()) {
|
||||
sum[gid.x] = sum_exp;
|
||||
}
|
||||
}
|
||||
}
|
||||
164
gpt_oss/metal/source/sdpa.metal
Normal file
164
gpt_oss/metal/source/sdpa.metal
Normal file
@@ -0,0 +1,164 @@
|
||||
#include <metal_geometric>
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
#include <metal_compute>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include <internal/kernel-args.h>
|
||||
|
||||
#pragma METAL fp math_mode(safe)
|
||||
#pragma METAL fp contract(off)
|
||||
|
||||
// Each threadgroup handles 8 Q heads / 1 KV head for 1 token
|
||||
|
||||
[[max_total_threads_per_threadgroup(32)]]
|
||||
kernel void gptoss_f32_sdpa_q8_d64(
|
||||
constant gptoss_sdpa_args& args [[ buffer(0) ]],
|
||||
const device float* q [[ buffer(1) ]],
|
||||
const device float* k [[ buffer(2) ]],
|
||||
const device float* v [[ buffer(3) ]],
|
||||
const device bfloat* s [[ buffer(4) ]],
|
||||
device float* output [[ buffer(5) ]],
|
||||
uint2 gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_index_in_threadgroup]])
|
||||
{
|
||||
const uint num_q_heads = 64;
|
||||
const uint num_kv_heads = 8;
|
||||
const uint head_dim = 64;
|
||||
const uint qmul = 8;
|
||||
|
||||
const uint qt = gid.x; // Q token index
|
||||
const uint h = gid.y; // KV head index
|
||||
|
||||
q += qt * args.qkv_dim + h * (qmul * head_dim);
|
||||
k += h * head_dim;
|
||||
v += h * head_dim;
|
||||
output += qt * (num_q_heads * head_dim) + h * (qmul * head_dim);
|
||||
|
||||
float m0 = static_cast<float>(s[h * qmul + 0]);
|
||||
float m1 = static_cast<float>(s[h * qmul + 1]);
|
||||
float m2 = static_cast<float>(s[h * qmul + 2]);
|
||||
float m3 = static_cast<float>(s[h * qmul + 3]);
|
||||
float m4 = static_cast<float>(s[h * qmul + 4]);
|
||||
float m5 = static_cast<float>(s[h * qmul + 5]);
|
||||
float m6 = static_cast<float>(s[h * qmul + 6]);
|
||||
float m7 = static_cast<float>(s[h * qmul + 7]);
|
||||
|
||||
float l0 = 1.0f;
|
||||
float l1 = 1.0f;
|
||||
float l2 = 1.0f;
|
||||
float l3 = 1.0f;
|
||||
float l4 = 1.0f;
|
||||
float l5 = 1.0f;
|
||||
float l6 = 1.0f;
|
||||
float l7 = 1.0f;
|
||||
|
||||
float2 out0 = 0.0f;
|
||||
float2 out1 = 0.0f;
|
||||
float2 out2 = 0.0f;
|
||||
float2 out3 = 0.0f;
|
||||
float2 out4 = 0.0f;
|
||||
float2 out5 = 0.0f;
|
||||
float2 out6 = 0.0f;
|
||||
float2 out7 = 0.0f;
|
||||
|
||||
float2 q0 = reinterpret_cast<const device float2*>(q + 0 * head_dim)[tid];
|
||||
float2 q1 = reinterpret_cast<const device float2*>(q + 1 * head_dim)[tid];
|
||||
float2 q2 = reinterpret_cast<const device float2*>(q + 2 * head_dim)[tid];
|
||||
float2 q3 = reinterpret_cast<const device float2*>(q + 3 * head_dim)[tid];
|
||||
float2 q4 = reinterpret_cast<const device float2*>(q + 4 * head_dim)[tid];
|
||||
float2 q5 = reinterpret_cast<const device float2*>(q + 5 * head_dim)[tid];
|
||||
float2 q6 = reinterpret_cast<const device float2*>(q + 6 * head_dim)[tid];
|
||||
float2 q7 = reinterpret_cast<const device float2*>(q + 7 * head_dim)[tid];
|
||||
|
||||
const uint kt_end = qt + args.num_kv_tokens + 1;
|
||||
const uint kt_start = metal::subsat(kt_end, args.window);
|
||||
k += 2 * num_kv_heads * head_dim * kt_start;
|
||||
v += 2 * num_kv_heads * head_dim * kt_start;
|
||||
for (uint kt = kt_start; kt < kt_end; kt++) {
|
||||
const float2 kval = reinterpret_cast<const device float2*>(k)[tid];
|
||||
k += 2 * num_kv_heads * head_dim;
|
||||
|
||||
float qk0 = metal::dot(q0, kval);
|
||||
float qk1 = metal::dot(q1, kval);
|
||||
float qk2 = metal::dot(q2, kval);
|
||||
float qk3 = metal::dot(q3, kval);
|
||||
float qk4 = metal::dot(q4, kval);
|
||||
float qk5 = metal::dot(q5, kval);
|
||||
float qk6 = metal::dot(q6, kval);
|
||||
float qk7 = metal::dot(q7, kval);
|
||||
|
||||
qk0 = metal::simd_sum(qk0);
|
||||
qk1 = metal::simd_sum(qk1);
|
||||
qk2 = metal::simd_sum(qk2);
|
||||
qk3 = metal::simd_sum(qk3);
|
||||
qk4 = metal::simd_sum(qk4);
|
||||
qk5 = metal::simd_sum(qk5);
|
||||
qk6 = metal::simd_sum(qk6);
|
||||
qk7 = metal::simd_sum(qk7);
|
||||
|
||||
const float new_m0 = metal::max(m0, qk0);
|
||||
const float new_m1 = metal::max(m1, qk1);
|
||||
const float new_m2 = metal::max(m2, qk2);
|
||||
const float new_m3 = metal::max(m3, qk3);
|
||||
const float new_m4 = metal::max(m4, qk4);
|
||||
const float new_m5 = metal::max(m5, qk5);
|
||||
const float new_m6 = metal::max(m6, qk6);
|
||||
const float new_m7 = metal::max(m7, qk7);
|
||||
|
||||
const float alpha0 = metal::fast::exp(m0 - new_m0);
|
||||
const float alpha1 = metal::fast::exp(m1 - new_m1);
|
||||
const float alpha2 = metal::fast::exp(m2 - new_m2);
|
||||
const float alpha3 = metal::fast::exp(m3 - new_m3);
|
||||
const float alpha4 = metal::fast::exp(m4 - new_m4);
|
||||
const float alpha5 = metal::fast::exp(m5 - new_m5);
|
||||
const float alpha6 = metal::fast::exp(m6 - new_m6);
|
||||
const float alpha7 = metal::fast::exp(m7 - new_m7);
|
||||
|
||||
qk0 = metal::fast::exp(qk0 - new_m0);
|
||||
qk1 = metal::fast::exp(qk1 - new_m1);
|
||||
qk2 = metal::fast::exp(qk2 - new_m2);
|
||||
qk3 = metal::fast::exp(qk3 - new_m3);
|
||||
qk4 = metal::fast::exp(qk4 - new_m4);
|
||||
qk5 = metal::fast::exp(qk5 - new_m5);
|
||||
qk6 = metal::fast::exp(qk6 - new_m6);
|
||||
qk7 = metal::fast::exp(qk7 - new_m7);
|
||||
|
||||
l0 = metal::fma(l0, alpha0, qk0);
|
||||
l1 = metal::fma(l1, alpha1, qk1);
|
||||
l2 = metal::fma(l2, alpha2, qk2);
|
||||
l3 = metal::fma(l3, alpha3, qk3);
|
||||
l4 = metal::fma(l4, alpha4, qk4);
|
||||
l5 = metal::fma(l5, alpha5, qk5);
|
||||
l6 = metal::fma(l6, alpha6, qk6);
|
||||
l7 = metal::fma(l7, alpha7, qk7);
|
||||
|
||||
m0 = new_m0;
|
||||
m1 = new_m1;
|
||||
m2 = new_m2;
|
||||
m3 = new_m3;
|
||||
m4 = new_m4;
|
||||
m5 = new_m5;
|
||||
m6 = new_m6;
|
||||
m7 = new_m7;
|
||||
|
||||
const float2 vval = reinterpret_cast<const device float2*>(v)[tid];
|
||||
v += 2 * num_kv_heads * head_dim;
|
||||
out0 = metal::fma(vval, qk0, out0 * alpha0);
|
||||
out1 = metal::fma(vval, qk1, out1 * alpha1);
|
||||
out2 = metal::fma(vval, qk2, out2 * alpha2);
|
||||
out3 = metal::fma(vval, qk3, out3 * alpha3);
|
||||
out4 = metal::fma(vval, qk4, out4 * alpha4);
|
||||
out5 = metal::fma(vval, qk5, out5 * alpha5);
|
||||
out6 = metal::fma(vval, qk6, out6 * alpha6);
|
||||
out7 = metal::fma(vval, qk7, out7 * alpha7);
|
||||
}
|
||||
reinterpret_cast<device float2*>(output + 0 * head_dim)[tid] = out0 / l0;
|
||||
reinterpret_cast<device float2*>(output + 1 * head_dim)[tid] = out1 / l1;
|
||||
reinterpret_cast<device float2*>(output + 2 * head_dim)[tid] = out2 / l2;
|
||||
reinterpret_cast<device float2*>(output + 3 * head_dim)[tid] = out3 / l3;
|
||||
reinterpret_cast<device float2*>(output + 4 * head_dim)[tid] = out4 / l4;
|
||||
reinterpret_cast<device float2*>(output + 5 * head_dim)[tid] = out5 / l5;
|
||||
reinterpret_cast<device float2*>(output + 6 * head_dim)[tid] = out6 / l6;
|
||||
reinterpret_cast<device float2*>(output + 7 * head_dim)[tid] = out7 / l7;
|
||||
}
|
||||
106
gpt_oss/metal/source/tokenizer.c
Normal file
106
gpt_oss/metal/source/tokenizer.c
Normal file
@@ -0,0 +1,106 @@
|
||||
#include <assert.h>
|
||||
#include <stdatomic.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <errno.h>
|
||||
#include <sys/mman.h>
|
||||
|
||||
#include <gpt-oss.h>
|
||||
|
||||
#include "internal/log.h"
|
||||
#include "internal/model.h"
|
||||
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_special_token_id(
|
||||
gptoss_tokenizer_t tokenizer,
|
||||
enum gptoss_special_token token_type,
|
||||
uint32_t* token_id_out)
|
||||
{
|
||||
uint32_t token_id = UINT32_MAX;
|
||||
if (token_type != gptoss_special_token_invalid && token_type < gptoss_special_token_max)
|
||||
{
|
||||
token_id = tokenizer->special_token_id[(uint32_t) token_type - 1];
|
||||
}
|
||||
if (token_id == UINT32_MAX) {
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
*token_id_out = token_id;
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_text_tokens(
|
||||
gptoss_tokenizer_t tokenizer,
|
||||
uint32_t* num_text_tokens_out)
|
||||
{
|
||||
*num_text_tokens_out = tokenizer->num_text_tokens;
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_special_tokens(
|
||||
gptoss_tokenizer_t tokenizer,
|
||||
uint32_t* num_special_tokens_out)
|
||||
{
|
||||
*num_special_tokens_out = tokenizer->num_special_tokens;
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_tokens(
|
||||
gptoss_tokenizer_t tokenizer,
|
||||
uint32_t* num_tokens_out)
|
||||
{
|
||||
*num_tokens_out = tokenizer->num_text_tokens + tokenizer->num_special_tokens;
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_decode(
|
||||
gptoss_tokenizer_t tokenizer,
|
||||
uint32_t token_id,
|
||||
const void** token_ptr_out,
|
||||
size_t* token_size_out)
|
||||
{
|
||||
if (token_id >= tokenizer->num_text_tokens) {
|
||||
return gptoss_status_invalid_argument;
|
||||
}
|
||||
|
||||
const char* token_ptr = (const char*) tokenizer->tokens_ptr;
|
||||
for (uint32_t t = 0; t < token_id; t++) {
|
||||
// Reading unaligned uint16_t
|
||||
uint16_t token_length;
|
||||
memcpy(&token_length, token_ptr, sizeof(token_length));
|
||||
|
||||
token_ptr += (size_t) token_length + sizeof(uint16_t);
|
||||
}
|
||||
|
||||
*token_ptr_out = (const void*) (token_ptr + sizeof(uint16_t));
|
||||
*token_size_out = (size_t) *token_ptr;
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_retain(
|
||||
gptoss_tokenizer_t tokenizer)
|
||||
{
|
||||
atomic_fetch_add_explicit(&tokenizer->ref_count, 1, memory_order_relaxed);
|
||||
return gptoss_status_success;
|
||||
}
|
||||
|
||||
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_release(
|
||||
gptoss_tokenizer_t tokenizer)
|
||||
{
|
||||
if (tokenizer != NULL) {
|
||||
if (atomic_fetch_sub_explicit(&tokenizer->ref_count, 1, memory_order_acquire) == 1) {
|
||||
if (tokenizer->mapping_ptr != NULL && tokenizer->mapping_size != 0) {
|
||||
if (munmap(tokenizer->mapping_ptr, tokenizer->mapping_size) != 0) {
|
||||
GPTOSS_LOG_WARNING("munmap for tokenizer mapping failed with error %d", errno);
|
||||
}
|
||||
}
|
||||
|
||||
memset(tokenizer, 0, sizeof(struct gptoss_tokenizer));
|
||||
free(tokenizer);
|
||||
}
|
||||
}
|
||||
return gptoss_status_success;
|
||||
}
|
||||
197
gpt_oss/metal/source/topk.metal
Normal file
197
gpt_oss/metal/source/topk.metal
Normal file
@@ -0,0 +1,197 @@
|
||||
#include <metal_compute>
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include <internal/kernel-args.h>
|
||||
|
||||
#pragma METAL fp math_mode(safe)
|
||||
#pragma METAL fp contract(off)
|
||||
|
||||
|
||||
[[max_total_threads_per_threadgroup(32)]]
|
||||
kernel void gptoss_f32_topk_softmax_e128_k4(
|
||||
constant gptoss_topk_args& args [[ buffer(0) ]],
|
||||
const device float4* input [[ buffer(1) ]],
|
||||
device gptoss_expert_prediction* output [[ buffer(2) ]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_position_in_threadgroup]])
|
||||
{
|
||||
const uint num_experts = 128;
|
||||
const uint num_active_experts = 4;
|
||||
|
||||
input += gid * (num_experts / 4);
|
||||
output += gid * num_active_experts;
|
||||
|
||||
uint4 idx = tid * 4 + (uint4) {0, 1, 2, 3};
|
||||
float4 val = input[tid];
|
||||
|
||||
const float topval0 = metal::simd_max(metal::max3(metal::max(val.x, val.y), val.z, val.w));
|
||||
uint idx0 = 0xFFFFFFFFu;
|
||||
if (val.w == topval0) {
|
||||
idx0 = idx.w;
|
||||
}
|
||||
if (val.z == topval0) {
|
||||
idx0 = idx.z;
|
||||
}
|
||||
if (val.y == topval0) {
|
||||
idx0 = idx.y;
|
||||
}
|
||||
if (val.x == topval0) {
|
||||
idx0 = idx.x;
|
||||
}
|
||||
const uint topidx0 = metal::simd_min(idx0);
|
||||
const bool4 is_topidx0 = idx == topidx0;
|
||||
val = metal::select(val, -INFINITY, is_topidx0);
|
||||
idx = metal::select(idx, 0xFFFFFFFFu, is_topidx0);
|
||||
|
||||
const float topval1 = metal::simd_max(metal::max3(metal::max(val.x, val.y), val.z, val.w));
|
||||
uint idx1 = 0xFFFFFFFFu;
|
||||
if (val.w == topval1) {
|
||||
idx1 = idx.w;
|
||||
}
|
||||
if (val.z == topval1) {
|
||||
idx1 = idx.z;
|
||||
}
|
||||
if (val.y == topval1) {
|
||||
idx1 = idx.y;
|
||||
}
|
||||
if (val.x == topval1) {
|
||||
idx1 = idx.x;
|
||||
}
|
||||
const uint topidx1 = metal::simd_min(idx1);
|
||||
const bool4 is_topidx1 = idx == topidx1;
|
||||
val = metal::select(val, -INFINITY, is_topidx1);
|
||||
idx = metal::select(idx, 0xFFFFFFFFu, is_topidx1);
|
||||
|
||||
const float topval2 = metal::simd_max(metal::max3(metal::max(val.x, val.y), val.z, val.w));
|
||||
uint idx2 = 0xFFFFFFFFu;
|
||||
if (val.w == topval2) {
|
||||
idx2 = idx.w;
|
||||
}
|
||||
if (val.z == topval2) {
|
||||
idx2 = idx.z;
|
||||
}
|
||||
if (val.y == topval2) {
|
||||
idx2 = idx.y;
|
||||
}
|
||||
if (val.x == topval2) {
|
||||
idx2 = idx.x;
|
||||
}
|
||||
const uint topidx2 = metal::simd_min(idx2);
|
||||
const bool4 is_topidx2 = idx == topidx2;
|
||||
val = metal::select(val, -INFINITY, is_topidx2);
|
||||
idx = metal::select(idx, 0xFFFFFFFFu, is_topidx2);
|
||||
|
||||
const float topval3 = metal::simd_max(metal::max3(metal::max(val.x, val.y), val.z, val.w));
|
||||
uint idx3 = 0xFFFFFFFFu;
|
||||
if (val.w == topval3) {
|
||||
idx3 = idx.w;
|
||||
}
|
||||
if (val.z == topval3) {
|
||||
idx3 = idx.z;
|
||||
}
|
||||
if (val.y == topval3) {
|
||||
idx3 = idx.y;
|
||||
}
|
||||
if (val.x == topval3) {
|
||||
idx3 = idx.x;
|
||||
}
|
||||
const uint topidx3 = metal::simd_min(idx3);
|
||||
|
||||
if (metal::simd_is_first()) {
|
||||
const float topexp0 = 1.0f;
|
||||
const float topexp1 = metal::precise::exp(topval1 - topval0);
|
||||
const float topexp2 = metal::precise::exp(topval2 - topval0);
|
||||
const float topexp3 = metal::precise::exp(topval3 - topval0);
|
||||
|
||||
const float sum = (topexp0 + topexp1) + (topexp2 + topexp3);
|
||||
const float scale = 1.0 / sum;
|
||||
|
||||
output[0] = (gptoss_expert_prediction) {
|
||||
.expert_id = topidx0,
|
||||
.score = topexp0 * scale,
|
||||
};
|
||||
output[1] = (gptoss_expert_prediction) {
|
||||
.expert_id = topidx1,
|
||||
.score = topexp1 * scale,
|
||||
};
|
||||
output[2] = (gptoss_expert_prediction) {
|
||||
.expert_id = topidx2,
|
||||
.score = topexp2 * scale,
|
||||
};
|
||||
output[3] = (gptoss_expert_prediction) {
|
||||
.expert_id = topidx3,
|
||||
.score = topexp3 * scale,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
[[max_total_threads_per_threadgroup(32)]]
|
||||
kernel void gptoss_f32_topk_softmax_e32_k4(
|
||||
constant gptoss_topk_args& args [[ buffer(0) ]],
|
||||
const device float* input [[ buffer(1) ]],
|
||||
device gptoss_expert_prediction* output [[ buffer(2) ]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_position_in_threadgroup]])
|
||||
{
|
||||
const uint num_experts = 32;
|
||||
const uint num_active_experts = 4;
|
||||
|
||||
input += gid * num_experts;
|
||||
output += gid * num_active_experts;
|
||||
|
||||
float val = input[tid];
|
||||
uint idx = tid;
|
||||
|
||||
const float topval0 = metal::simd_max(val);
|
||||
const uint topidx0 = metal::simd_min(val == topval0 ? idx : 0xFFFFFFFFu);
|
||||
if (idx == topidx0) {
|
||||
val = -INFINITY;
|
||||
idx = 0xFFFFFFFFu;
|
||||
}
|
||||
|
||||
const float topval1 = metal::simd_max(val);
|
||||
const uint topidx1 = metal::simd_min(val == topval1 ? idx : 0xFFFFFFFFu);
|
||||
if (idx == topidx1) {
|
||||
val = -INFINITY;
|
||||
idx = 0xFFFFFFFFu;
|
||||
}
|
||||
|
||||
const float topval2 = metal::simd_max(val);
|
||||
const uint topidx2 = metal::simd_min(val == topval2 ? idx : 0xFFFFFFFFu);
|
||||
if (idx == topidx2) {
|
||||
val = -INFINITY;
|
||||
idx = 0xFFFFFFFFu;
|
||||
}
|
||||
|
||||
const float topval3 = metal::simd_max(val);
|
||||
const uint topidx3 = metal::simd_min(val == topval3 ? idx : 0xFFFFFFFFu);
|
||||
|
||||
if (metal::simd_is_first()) {
|
||||
const float topexp0 = 1.0f;
|
||||
const float topexp1 = metal::precise::exp(topval1 - topval0);
|
||||
const float topexp2 = metal::precise::exp(topval2 - topval0);
|
||||
const float topexp3 = metal::precise::exp(topval3 - topval0);
|
||||
|
||||
const float sum = (topexp0 + topexp1) + (topexp2 + topexp3);
|
||||
const float scale = 1.0 / sum;
|
||||
|
||||
output[0] = (gptoss_expert_prediction) {
|
||||
.expert_id = topidx0,
|
||||
.score = topexp0 * scale,
|
||||
};
|
||||
output[1] = (gptoss_expert_prediction) {
|
||||
.expert_id = topidx1,
|
||||
.score = topexp1 * scale,
|
||||
};
|
||||
output[2] = (gptoss_expert_prediction) {
|
||||
.expert_id = topidx2,
|
||||
.score = topexp2 * scale,
|
||||
};
|
||||
output[3] = (gptoss_expert_prediction) {
|
||||
.expert_id = topidx3,
|
||||
.score = topexp3 * scale,
|
||||
};
|
||||
}
|
||||
}
|
||||
33
gpt_oss/metal/test/bf16-f32-embeddings.cc
Normal file
33
gpt_oss/metal/test/bf16-f32-embeddings.cc
Normal file
@@ -0,0 +1,33 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
#include "embeddings-kernel-tester.hpp"
|
||||
|
||||
|
||||
using gptoss::EmbeddingsKernelTester;
|
||||
|
||||
constexpr std::size_t kThreadgroupSize = 64;
|
||||
|
||||
|
||||
TEST(BF16_F32_EMBEDDINGS, single_token_single_tile) {
|
||||
EmbeddingsKernelTester()
|
||||
.num_channels(kThreadgroupSize)
|
||||
.threadgroup_size(kThreadgroupSize)
|
||||
.TestBF16_F32();
|
||||
}
|
||||
|
||||
TEST(BF16_F32_EMBEDDINGS, single_token_multi_tile) {
|
||||
EmbeddingsKernelTester()
|
||||
.num_channels(kThreadgroupSize * 4 + 16)
|
||||
.threadgroup_size(kThreadgroupSize)
|
||||
.TestBF16_F32();
|
||||
}
|
||||
|
||||
TEST(BF16_F32_EMBEDDINGS, multiple_tokens) {
|
||||
EmbeddingsKernelTester()
|
||||
.num_channels(kThreadgroupSize * 4 + 16)
|
||||
.num_tokens(3)
|
||||
.threadgroup_size(kThreadgroupSize)
|
||||
.TestBF16_F32();
|
||||
}
|
||||
119
gpt_oss/metal/test/embeddings-kernel-tester.hpp
Normal file
119
gpt_oss/metal/test/embeddings-kernel-tester.hpp
Normal file
@@ -0,0 +1,119 @@
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
#include <internal/datatype.hpp>
|
||||
#include <internal/metal.hpp>
|
||||
#include <internal/metal-kernels.h>
|
||||
|
||||
|
||||
namespace gptoss {
|
||||
|
||||
class EmbeddingsKernelTester {
|
||||
public:
|
||||
EmbeddingsKernelTester() { }
|
||||
|
||||
EmbeddingsKernelTester(const EmbeddingsKernelTester&) = delete;
|
||||
EmbeddingsKernelTester(EmbeddingsKernelTester&&) = delete;
|
||||
EmbeddingsKernelTester& operator=(const EmbeddingsKernelTester&) = delete;
|
||||
EmbeddingsKernelTester& operator=(EmbeddingsKernelTester&&) = delete;
|
||||
|
||||
[[nodiscard]]
|
||||
EmbeddingsKernelTester& num_channels(std::uint32_t num_channels) {
|
||||
num_channels_ = num_channels;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::uint32_t num_channels() const {
|
||||
return num_channels_;
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
EmbeddingsKernelTester& num_tokens(std::uint32_t num_tokens) {
|
||||
num_tokens_ = num_tokens;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::uint32_t num_tokens() const {
|
||||
return num_tokens_;
|
||||
}
|
||||
|
||||
std::uint32_t vocabulary_size() const {
|
||||
return num_tokens() + 1;
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
EmbeddingsKernelTester& threadgroup_size(std::size_t threadgroup_size) {
|
||||
threadgroup_size_ = threadgroup_size;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::size_t threadgroup_size() const {
|
||||
return threadgroup_size_;
|
||||
}
|
||||
|
||||
void Validate() const {
|
||||
ASSERT_NE(num_channels(), 0);
|
||||
ASSERT_NE(num_tokens(), 0);
|
||||
ASSERT_NE(threadgroup_size(), 0);
|
||||
ASSERT_EQ(threadgroup_size() % 32, 0);
|
||||
}
|
||||
|
||||
void TestBF16_F32() const {
|
||||
Validate();
|
||||
|
||||
metal::CommandBuffer command_buffer{command_queue_};
|
||||
metal::Buffer token_buffer{device_, sizeof(std::uint32_t)};
|
||||
metal::Buffer weight_buffer{device_, vocabulary_size() * num_channels() * sizeof(gptoss_bfloat16)};
|
||||
metal::Buffer output_buffer{device_, num_channels() * sizeof(float)};
|
||||
|
||||
std::uint32_t* token_ptr = static_cast<std::uint32_t*>(token_buffer.ptr());
|
||||
for (std::uint32_t t = 0; t < num_tokens(); t++) {
|
||||
token_ptr[t] = t + 1;
|
||||
}
|
||||
|
||||
Check(gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(
|
||||
command_buffer.handle(),
|
||||
bf16_f32_embeddings_fn.handle(),
|
||||
threadgroup_size(),
|
||||
token_buffer.handle(),
|
||||
/*token_offset=*/0,
|
||||
weight_buffer.handle(),
|
||||
/*weight_offset=*/0,
|
||||
output_buffer.handle(),
|
||||
/*output_offset=*/0,
|
||||
num_tokens(),
|
||||
num_channels()),
|
||||
"gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings");
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_completion();
|
||||
|
||||
const gptoss_bfloat16* weight_ptr = static_cast<const gptoss_bfloat16*>(weight_buffer.ptr());
|
||||
const float* output_ptr = static_cast<const float*>(output_buffer.ptr());
|
||||
for (std::uint32_t t = 0; t < num_tokens(); t++) {
|
||||
const std::uint32_t token = token_ptr[t];
|
||||
for (std::uint32_t i = 0; i < num_channels(); i++) {
|
||||
const gptoss_bfloat16 input_val = weight_ptr[token * num_channels() + i];
|
||||
const float ref_output = upcast<float>(input_val);
|
||||
const float output = output_ptr[t * num_channels() + i];
|
||||
ASSERT_EQ(output, ref_output)
|
||||
<< "at token " << t << ", position " << i << " / " << num_channels() << ", input " << std::uint32_t(input_val.bits);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
metal::Device device_{};
|
||||
metal::CommandQueue command_queue_{device_};
|
||||
metal::Library library_{device_};
|
||||
metal::Function bf16_f32_embeddings_fn{library_, "gptoss_bf16_f32_embeddings"};
|
||||
std::uint32_t num_tokens_{1};
|
||||
std::uint32_t num_channels_{1};
|
||||
std::size_t threadgroup_size_{32};
|
||||
};
|
||||
|
||||
} // namespace gptoss
|
||||
60
gpt_oss/metal/test/f32-bf16w-matmul.cc
Normal file
60
gpt_oss/metal/test/f32-bf16w-matmul.cc
Normal file
@@ -0,0 +1,60 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
#include "matmul-kernel-tester.hpp"
|
||||
|
||||
|
||||
using gptoss::MatMulKernelTester;
|
||||
|
||||
constexpr size_t kSimdgroupSize = 32; // fixed in the kernel
|
||||
|
||||
TEST(F32_BF16W_MATMUL, single_simdgroup_single_iteration) {
|
||||
MatMulKernelTester()
|
||||
.num_rows(1)
|
||||
.num_cols(kSimdgroupSize * 4)
|
||||
.threadgroup_size(kSimdgroupSize)
|
||||
.TestF32_BF16W();
|
||||
}
|
||||
|
||||
TEST(F32_BF16W_MATMUL, single_simdgroup_multiple_iteration) {
|
||||
MatMulKernelTester()
|
||||
.num_rows(1)
|
||||
.num_cols((2 * kSimdgroupSize + 1) * 4)
|
||||
.threadgroup_size(kSimdgroupSize)
|
||||
.TestF32_BF16W();
|
||||
}
|
||||
|
||||
TEST(F32_BF16W_MATMUL, single_threadgroup) {
|
||||
constexpr std::size_t threadgroup_size = 2 * kSimdgroupSize;
|
||||
|
||||
MatMulKernelTester()
|
||||
.num_rows(threadgroup_size / kSimdgroupSize)
|
||||
.num_cols((2 * kSimdgroupSize + 1) * 4)
|
||||
.threadgroup_size(threadgroup_size)
|
||||
.TestF32_BF16W();
|
||||
}
|
||||
|
||||
TEST(F32_BF16W_MATMUL, multiple_threadgroups) {
|
||||
constexpr std::size_t threadgroup_size = 2 * kSimdgroupSize;
|
||||
constexpr std::uint32_t num_threadgroups = 3;
|
||||
|
||||
MatMulKernelTester()
|
||||
.num_rows(num_threadgroups * threadgroup_size / kSimdgroupSize)
|
||||
.num_cols((2 * kSimdgroupSize + 1) * 4)
|
||||
.threadgroup_size(threadgroup_size)
|
||||
.TestF32_BF16W();
|
||||
}
|
||||
|
||||
TEST(F32_BF16W_MATMUL, multiple_tokens) {
|
||||
constexpr std::size_t threadgroup_size = 2 * kSimdgroupSize;
|
||||
constexpr std::uint32_t num_threadgroups = 3;
|
||||
|
||||
MatMulKernelTester()
|
||||
.num_rows(num_threadgroups * threadgroup_size / kSimdgroupSize)
|
||||
.num_cols((2 * kSimdgroupSize + 1) * 4)
|
||||
.num_tokens(2)
|
||||
.threadgroup_size(threadgroup_size)
|
||||
.TestF32_BF16W();
|
||||
}
|
||||
36
gpt_oss/metal/test/f32-bf16w-rmsnorm.cc
Normal file
36
gpt_oss/metal/test/f32-bf16w-rmsnorm.cc
Normal file
@@ -0,0 +1,36 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "rmsnorm-kernel-tester.hpp"
|
||||
|
||||
|
||||
using gptoss::RMSNormKernelTester;
|
||||
|
||||
constexpr std::uint32_t kThreadgroupSize = 1024; // fixed in the kernel
|
||||
constexpr std::uint32_t kVectorSize = 4; // fixed in the kernel
|
||||
|
||||
TEST(F32_BF16W_RMSNORM, single_iteration) {
|
||||
RMSNormKernelTester()
|
||||
.num_channels(kThreadgroupSize)
|
||||
.TestF32_BF16W();
|
||||
}
|
||||
|
||||
TEST(F32_BF16W_RMSNORM, multiple_iterations) {
|
||||
RMSNormKernelTester()
|
||||
.num_channels(kThreadgroupSize * 2)
|
||||
.TestF32_BF16W();
|
||||
}
|
||||
|
||||
TEST(F32_BF16W_RMSNORM, partial_iteration) {
|
||||
RMSNormKernelTester()
|
||||
.num_channels(kThreadgroupSize * 2 + kVectorSize)
|
||||
.TestF32_BF16W();
|
||||
}
|
||||
|
||||
TEST(F32_BF16W_RMSNORM, multiple_tokens) {
|
||||
RMSNormKernelTester()
|
||||
.num_tokens(3)
|
||||
.num_channels(kThreadgroupSize * 2 + kVectorSize)
|
||||
.TestF32_BF16W();
|
||||
}
|
||||
230
gpt_oss/metal/test/f32-random.cc
Normal file
230
gpt_oss/metal/test/f32-random.cc
Normal file
@@ -0,0 +1,230 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include <internal/metal.hpp>
|
||||
#include <internal/metal-kernels.h>
|
||||
#include <internal/rng.hpp>
|
||||
|
||||
using gptoss::Check;
|
||||
using namespace gptoss::metal;
|
||||
|
||||
|
||||
constexpr uint64_t kSeed = UINT64_C(1019827666124465388);
|
||||
constexpr uint64_t kOffset = UINT64_C(12345678901234567890);
|
||||
constexpr float kMin = -1.0f;
|
||||
constexpr float kMax = +1.5f;
|
||||
constexpr float kScale = (kMax - kMin) * 0.5f;
|
||||
constexpr float kBias = (kMin + kMax) * 0.5f;
|
||||
constexpr size_t kThreadgroupSize = 128;
|
||||
|
||||
TEST(F32_FILL_RANDOM, single_threadgroup_single_iteration) {
|
||||
constexpr size_t num_bytes = kThreadgroupSize * 16;
|
||||
constexpr size_t num_elements = num_bytes / sizeof(uint32_t);
|
||||
|
||||
Device device;
|
||||
CommandQueue command_queue{device};
|
||||
CommandBuffer command_buffer{command_queue};
|
||||
Library library{device};
|
||||
Function f32_fill_random_fn{library, "gptoss_f32_fill_random"};
|
||||
Buffer buffer{device, num_elements * sizeof(float)};
|
||||
|
||||
Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(
|
||||
command_buffer.handle(),
|
||||
f32_fill_random_fn.handle(),
|
||||
/*threadgroup_size=*/kThreadgroupSize,
|
||||
/*max_threadgroups=*/1,
|
||||
/*output_buffer=*/buffer.handle(),
|
||||
/*output_offset=*/0,
|
||||
num_elements, kSeed, kOffset, kMin, kMax),
|
||||
"gptoss_metal_command_buffer_encode_launch_f32_fill_random");
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_completion();
|
||||
|
||||
const float* output_ptr = static_cast<const float*>(buffer.ptr());
|
||||
for (size_t i = 0; i < num_elements; i++) {
|
||||
const uint32_t ref_word = gptoss::rng::squares32(kOffset + i, kSeed);
|
||||
const float ref_float = static_cast<int32_t>(ref_word) * 0x1.0p-31f;
|
||||
const float ref_value = std::fma(ref_float, kScale, kBias);
|
||||
ASSERT_EQ(output_ptr[i], ref_value)
|
||||
<< "at position " << i << " / " << num_elements;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(F32_FILL_RANDOM, single_threadgroup_multiple_iterations) {
|
||||
constexpr size_t num_iterations = 3;
|
||||
constexpr size_t num_bytes = num_iterations * kThreadgroupSize * 16;
|
||||
constexpr size_t num_elements = num_bytes / sizeof(uint32_t);
|
||||
|
||||
Device device;
|
||||
CommandQueue command_queue{device};
|
||||
CommandBuffer command_buffer{command_queue};
|
||||
Library library{device};
|
||||
Function f32_fill_random_fn{library, "gptoss_f32_fill_random"};
|
||||
Buffer buffer{device, num_elements * sizeof(float)};
|
||||
|
||||
Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(
|
||||
command_buffer.handle(),
|
||||
f32_fill_random_fn.handle(),
|
||||
/*threadgroup_size=*/kThreadgroupSize,
|
||||
/*max_threadgroups=*/1,
|
||||
/*output_buffer=*/buffer.handle(),
|
||||
/*output_offset=*/0,
|
||||
num_elements, kSeed, kOffset, kMin, kMax),
|
||||
"gptoss_metal_command_buffer_encode_launch_f32_fill_random");
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_completion();
|
||||
|
||||
const float* output_ptr = static_cast<const float*>(buffer.ptr());
|
||||
for (size_t i = 0; i < num_elements; i++) {
|
||||
const uint32_t ref_word = gptoss::rng::squares32(kOffset + i, kSeed);
|
||||
const float ref_float = static_cast<int32_t>(ref_word) * 0x1.0p-31f;
|
||||
const float ref_value = std::fma(ref_float, kScale, kBias);
|
||||
ASSERT_EQ(output_ptr[i], ref_value)
|
||||
<< "at position " << i << " / " << num_elements;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(F32_FILL_RANDOM, multiple_threadgroups_multiple_iterations) {
|
||||
constexpr size_t num_iterations = 3;
|
||||
constexpr size_t num_threadgroups = 2;
|
||||
constexpr size_t num_bytes = num_iterations * num_threadgroups * kThreadgroupSize * 16;
|
||||
constexpr size_t num_elements = num_bytes / sizeof(uint32_t);
|
||||
|
||||
Device device;
|
||||
CommandQueue command_queue{device};
|
||||
CommandBuffer command_buffer{command_queue};
|
||||
Library library{device};
|
||||
Function f32_fill_random_fn{library, "gptoss_f32_fill_random"};
|
||||
Buffer buffer{device, num_elements * sizeof(float)};
|
||||
|
||||
Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(
|
||||
command_buffer.handle(),
|
||||
f32_fill_random_fn.handle(),
|
||||
/*threadgroup_size=*/kThreadgroupSize,
|
||||
/*max_threadgroups=*/num_threadgroups,
|
||||
/*output_buffer=*/buffer.handle(),
|
||||
/*output_offset=*/0,
|
||||
num_elements, kSeed, kOffset, kMin, kMax),
|
||||
"gptoss_metal_command_buffer_encode_launch_f32_fill_random");
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_completion();
|
||||
|
||||
const float* output_ptr = static_cast<const float*>(buffer.ptr());
|
||||
for (size_t i = 0; i < num_elements; i++) {
|
||||
const uint32_t ref_word = gptoss::rng::squares32(kOffset + i, kSeed);
|
||||
const float ref_float = static_cast<int32_t>(ref_word) * 0x1.0p-31f;
|
||||
const float ref_value = std::fma(ref_float, kScale, kBias);
|
||||
ASSERT_EQ(output_ptr[i], ref_value)
|
||||
<< "at position " << i << " / " << num_elements;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(F32_FILL_RANDOM, excessive_threadgroups) {
|
||||
constexpr size_t num_bytes = kThreadgroupSize * 16;
|
||||
constexpr size_t num_elements = num_bytes / sizeof(uint32_t);
|
||||
|
||||
Device device;
|
||||
CommandQueue command_queue{device};
|
||||
CommandBuffer command_buffer{command_queue};
|
||||
Library library{device};
|
||||
Function f32_fill_random_fn{library, "gptoss_f32_fill_random"};
|
||||
Buffer buffer{device, num_elements * sizeof(float)};
|
||||
|
||||
Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(
|
||||
command_buffer.handle(),
|
||||
f32_fill_random_fn.handle(),
|
||||
/*threadgroup_size=*/kThreadgroupSize,
|
||||
/*max_threadgroups=*/2,
|
||||
/*output_buffer=*/buffer.handle(),
|
||||
/*output_offset=*/0,
|
||||
num_elements, kSeed, kOffset, kMin, kMax),
|
||||
"gptoss_metal_command_buffer_encode_launch_f32_fill_random");
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_completion();
|
||||
|
||||
const float* output_ptr = static_cast<const float*>(buffer.ptr());
|
||||
for (size_t i = 0; i < num_elements; i++) {
|
||||
const uint32_t ref_word = gptoss::rng::squares32(kOffset + i, kSeed);
|
||||
const float ref_float = static_cast<int32_t>(ref_word) * 0x1.0p-31f;
|
||||
const float ref_value = std::fma(ref_float, kScale, kBias);
|
||||
ASSERT_EQ(output_ptr[i], ref_value)
|
||||
<< "at position " << i << " / " << num_elements;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(F32_FILL_RANDOM, nonuniform_range) {
|
||||
constexpr size_t num_iterations = 3;
|
||||
constexpr size_t num_threadgroups = 2;
|
||||
constexpr size_t num_bytes = (num_iterations * num_threadgroups + 1) * kThreadgroupSize * 16;
|
||||
constexpr size_t num_elements = num_bytes / sizeof(uint32_t);
|
||||
|
||||
Device device;
|
||||
CommandQueue command_queue{device};
|
||||
CommandBuffer command_buffer{command_queue};
|
||||
Library library{device};
|
||||
Function f32_fill_random_fn{library, "gptoss_f32_fill_random"};
|
||||
Buffer buffer{device, num_elements * sizeof(float)};
|
||||
|
||||
Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(
|
||||
command_buffer.handle(),
|
||||
f32_fill_random_fn.handle(),
|
||||
/*threadgroup_size=*/kThreadgroupSize,
|
||||
/*max_threadgroups=*/num_threadgroups,
|
||||
/*output_buffer=*/buffer.handle(),
|
||||
/*output_offset=*/0,
|
||||
num_elements, kSeed, kOffset, kMin, kMax),
|
||||
"gptoss_metal_command_buffer_encode_launch_f32_fill_random");
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_completion();
|
||||
|
||||
const float* output_ptr = static_cast<const float*>(buffer.ptr());
|
||||
for (size_t i = 0; i < num_elements; i++) {
|
||||
const uint32_t ref_word = gptoss::rng::squares32(kOffset + i, kSeed);
|
||||
const float ref_float = static_cast<int32_t>(ref_word) * 0x1.0p-31f;
|
||||
const float ref_value = std::fma(ref_float, kScale, kBias);
|
||||
ASSERT_EQ(output_ptr[i], ref_value)
|
||||
<< "at position " << i << " / " << num_elements;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(F32_FILL_RANDOM, partial_range) {
|
||||
constexpr size_t num_iterations = 3;
|
||||
constexpr size_t num_threadgroups = 2;
|
||||
constexpr size_t num_bytes = (num_iterations * num_threadgroups * kThreadgroupSize + 1) * 16;
|
||||
constexpr size_t num_elements = num_bytes / sizeof(uint32_t);
|
||||
|
||||
Device device;
|
||||
CommandQueue command_queue{device};
|
||||
CommandBuffer command_buffer{command_queue};
|
||||
Library library{device};
|
||||
Function f32_fill_random_fn{library, "gptoss_f32_fill_random"};
|
||||
Buffer buffer{device, num_elements * sizeof(float)};
|
||||
|
||||
Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(
|
||||
command_buffer.handle(),
|
||||
f32_fill_random_fn.handle(),
|
||||
/*threadgroup_size=*/kThreadgroupSize,
|
||||
/*max_threadgroups=*/num_threadgroups,
|
||||
/*output_buffer=*/buffer.handle(),
|
||||
/*output_offset=*/0,
|
||||
num_elements, kSeed, kOffset, kMin, kMax),
|
||||
"gptoss_metal_command_buffer_encode_launch_f32_fill_random");
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_completion();
|
||||
|
||||
const float* output_ptr = static_cast<const float*>(buffer.ptr());
|
||||
for (size_t i = 0; i < num_elements; i++) {
|
||||
const uint32_t ref_word = gptoss::rng::squares32(kOffset + i, kSeed);
|
||||
const float ref_float = static_cast<int32_t>(ref_word) * 0x1.0p-31f;
|
||||
const float ref_value = std::fma(ref_float, kScale, kBias);
|
||||
ASSERT_EQ(output_ptr[i], ref_value)
|
||||
<< "at position " << i << " / " << num_elements;
|
||||
}
|
||||
}
|
||||
71
gpt_oss/metal/test/f32-rope.cc
Normal file
71
gpt_oss/metal/test/f32-rope.cc
Normal file
@@ -0,0 +1,71 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
#include "rope-kernel-tester.hpp"
|
||||
|
||||
|
||||
using gptoss::RoPEKernelTester;
|
||||
|
||||
constexpr float kFrequencyBase = 50000.0f;
|
||||
constexpr std::uint32_t kHeadDim = 64; // fixed in the kernel
|
||||
constexpr std::uint32_t kTokenOffset = 7;
|
||||
|
||||
|
||||
TEST(F32_ROPE, single_simdgroup) {
|
||||
RoPEKernelTester()
|
||||
.head_dim(kHeadDim)
|
||||
.num_q_heads(1)
|
||||
.num_kv_heads(0)
|
||||
.token_offset(kTokenOffset)
|
||||
.frequency_base(kFrequencyBase)
|
||||
.threadgroup_size(kHeadDim / 2)
|
||||
.TestF32();
|
||||
}
|
||||
|
||||
TEST(F32_ROPE, single_threadgroup) {
|
||||
constexpr std::size_t threadgroup_size = 64;
|
||||
constexpr std::uint32_t num_heads = threadgroup_size / (kHeadDim / 2);
|
||||
|
||||
RoPEKernelTester()
|
||||
.head_dim(kHeadDim)
|
||||
.num_q_heads(num_heads)
|
||||
.num_kv_heads(0)
|
||||
.token_offset(kTokenOffset)
|
||||
.frequency_base(kFrequencyBase)
|
||||
.threadgroup_size(threadgroup_size)
|
||||
.TestF32();
|
||||
}
|
||||
|
||||
TEST(F32_ROPE, multiple_threadgroups) {
|
||||
constexpr std::uint32_t num_threadgroups = 3;
|
||||
constexpr std::size_t threadgroup_size = 64;
|
||||
constexpr std::uint32_t num_heads = num_threadgroups * (threadgroup_size / (kHeadDim / 2));
|
||||
|
||||
RoPEKernelTester()
|
||||
.head_dim(kHeadDim)
|
||||
.num_q_heads(num_heads)
|
||||
.num_kv_heads(0)
|
||||
.token_offset(kTokenOffset)
|
||||
.frequency_base(kFrequencyBase)
|
||||
.threadgroup_size(threadgroup_size)
|
||||
.TestF32();
|
||||
}
|
||||
|
||||
TEST(F32_ROPE, multiple_tokens) {
|
||||
constexpr std::uint32_t num_tokens = 2;
|
||||
constexpr std::uint32_t num_threadgroups = 3;
|
||||
constexpr std::size_t threadgroup_size = 64;
|
||||
constexpr std::uint32_t num_heads = num_threadgroups * (threadgroup_size / (kHeadDim / 2));
|
||||
|
||||
RoPEKernelTester()
|
||||
.head_dim(kHeadDim)
|
||||
.num_tokens(2)
|
||||
.num_q_heads(num_heads)
|
||||
.num_kv_heads(0)
|
||||
.token_offset(kTokenOffset)
|
||||
.frequency_base(kFrequencyBase)
|
||||
.threadgroup_size(threadgroup_size)
|
||||
.TestF32();
|
||||
}
|
||||
101
gpt_oss/metal/test/fill-random-kernel-tester.hpp
Normal file
101
gpt_oss/metal/test/fill-random-kernel-tester.hpp
Normal file
@@ -0,0 +1,101 @@
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
#include <internal/datatype.hpp>
|
||||
#include <internal/metal.hpp>
|
||||
#include <internal/metal-kernels.h>
|
||||
#include <internal/rng.hpp>
|
||||
|
||||
|
||||
namespace gptoss {
|
||||
|
||||
class FillRandomKernelTester {
|
||||
public:
|
||||
FillRandomKernelTester() { }
|
||||
|
||||
FillRandomKernelTester(const FillRandomKernelTester&) = delete;
|
||||
FillRandomKernelTester(FillRandomKernelTester&&) = delete;
|
||||
FillRandomKernelTester& operator=(const FillRandomKernelTester&) = delete;
|
||||
FillRandomKernelTester& operator=(FillRandomKernelTester&&) = delete;
|
||||
|
||||
[[nodiscard]]
|
||||
FillRandomKernelTester& num_elements(std::uint32_t num_elements) {
|
||||
num_elements_ = num_elements;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::uint32_t num_elements() const {
|
||||
return num_elements_;
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
FillRandomKernelTester& threadgroup_size(std::size_t threadgroup_size) {
|
||||
threadgroup_size_ = threadgroup_size;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::size_t threadgroup_size() const {
|
||||
return threadgroup_size_;
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
FillRandomKernelTester& max_threadgroups(std::size_t max_threadgroups) {
|
||||
max_threadgroups_ = max_threadgroups;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::size_t max_threadgroups() const {
|
||||
return max_threadgroups_;
|
||||
}
|
||||
|
||||
void Validate() const {
|
||||
ASSERT_NE(num_elements(), 0);
|
||||
ASSERT_NE(threadgroup_size(), 0);
|
||||
ASSERT_NE(max_threadgroups(), 0);
|
||||
}
|
||||
|
||||
void TestU32() const {
|
||||
Validate();
|
||||
|
||||
metal::Buffer output_buffer{device_, num_elements() * sizeof(std::uint32_t)};
|
||||
|
||||
metal::CommandBuffer command_buffer{command_queue_};
|
||||
command_buffer.encode_launch_u32_fill_random(
|
||||
u32_fill_random_fn_,
|
||||
threadgroup_size(),
|
||||
max_threadgroups(),
|
||||
output_buffer,
|
||||
/*output_offset=*/0,
|
||||
num_elements(), kSeed, kOffset);
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_completion();
|
||||
|
||||
const std::uint32_t* output_ptr = static_cast<const std::uint32_t*>(output_buffer.ptr());
|
||||
for (std::size_t i = 0; i < num_elements(); i++) {
|
||||
const std::uint32_t ref_value = gptoss::rng::squares32(kOffset + i, kSeed);
|
||||
ASSERT_EQ(output_ptr[i], ref_value)
|
||||
<< "at position " << i << " / " << num_elements();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr uint64_t kSeed{UINT64_C(1019827666124465388)};
|
||||
static constexpr uint64_t kOffset{UINT64_C(12345678901234567890)};
|
||||
|
||||
metal::Device device_{};
|
||||
metal::CommandQueue command_queue_{device_};
|
||||
metal::Library library_{device_};
|
||||
metal::Function f32_fill_random_fn_{library_, "gptoss_f32_fill_random"};
|
||||
metal::Function bf16_fill_random_fn_{library_, "gptoss_bf16_fill_random"};
|
||||
metal::Function u32_fill_random_fn_{library_, "gptoss_u32_fill_random"};
|
||||
std::uint32_t num_elements_{1};
|
||||
std::size_t threadgroup_size_{32};
|
||||
std::size_t max_threadgroups_{1};
|
||||
};
|
||||
|
||||
} // namespace gptoss
|
||||
164
gpt_oss/metal/test/matmul-kernel-tester.hpp
Normal file
164
gpt_oss/metal/test/matmul-kernel-tester.hpp
Normal file
@@ -0,0 +1,164 @@
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
#include <internal/datatype.hpp>
|
||||
#include <internal/metal.hpp>
|
||||
#include <internal/metal-kernels.h>
|
||||
|
||||
|
||||
namespace gptoss {
|
||||
|
||||
class MatMulKernelTester {
|
||||
public:
|
||||
MatMulKernelTester() { }
|
||||
|
||||
MatMulKernelTester(const MatMulKernelTester&) = delete;
|
||||
MatMulKernelTester(MatMulKernelTester&&) = delete;
|
||||
MatMulKernelTester& operator=(const MatMulKernelTester&) = delete;
|
||||
MatMulKernelTester& operator=(MatMulKernelTester&&) = delete;
|
||||
|
||||
[[nodiscard]]
|
||||
MatMulKernelTester& num_rows(std::uint32_t num_rows) {
|
||||
num_rows_ = num_rows;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::uint32_t num_rows() const {
|
||||
return num_rows_;
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
MatMulKernelTester& num_cols(std::uint32_t num_cols) {
|
||||
num_cols_ = num_cols;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::uint32_t num_cols() const {
|
||||
return num_cols_;
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
MatMulKernelTester& num_tokens(std::uint32_t num_tokens) {
|
||||
num_tokens_ = num_tokens;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::uint32_t num_tokens() const {
|
||||
return num_tokens_;
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
MatMulKernelTester& threadgroup_size(std::size_t threadgroup_size) {
|
||||
threadgroup_size_ = threadgroup_size;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::size_t threadgroup_size() const {
|
||||
return threadgroup_size_;
|
||||
}
|
||||
|
||||
void Validate(std::uint32_t vec_size) const {
|
||||
ASSERT_NE(num_rows(), 0);
|
||||
ASSERT_NE(num_cols(), 0);
|
||||
ASSERT_EQ(num_cols() % vec_size, 0);
|
||||
ASSERT_NE(num_tokens(), 0);
|
||||
ASSERT_NE(threadgroup_size(), 0);
|
||||
}
|
||||
|
||||
void TestF32_BF16W() const {
|
||||
Validate(/*vec_size=*/4);
|
||||
|
||||
metal::CommandBuffer command_buffer{command_queue_};
|
||||
metal::Buffer input_buffer{device_, num_tokens() * num_cols() * sizeof(float)};
|
||||
metal::Buffer weight_buffer{device_, num_rows() * num_cols() * sizeof(gptoss_bfloat16)};
|
||||
metal::Buffer bias_buffer{device_, num_rows() * sizeof(gptoss_bfloat16)};
|
||||
metal::Buffer output_buffer{device_, num_tokens() * num_rows() * sizeof(float)};
|
||||
|
||||
command_buffer.encode_launch_f32_fill_random(
|
||||
f32_fill_random_fn_,
|
||||
/*threadgroup_size=*/0,
|
||||
/*max_threadgroups=*/kFillRandomMaxThreadgroups,
|
||||
/*output_buffer=*/input_buffer,
|
||||
/*output_offset=*/0,
|
||||
num_tokens() * num_cols(), kSeed, /*offset=*/0, /*min=*/-1.0f, /*max=*/1.0);
|
||||
|
||||
command_buffer.encode_launch_bf16_fill_random(
|
||||
bf16_fill_random_fn_,
|
||||
/*threadgroup_size=*/0,
|
||||
/*max_threadgroups=*/kFillRandomMaxThreadgroups,
|
||||
/*output_buffer=*/weight_buffer,
|
||||
/*output_offset=*/0,
|
||||
num_rows() * num_cols(), kSeed + 1, /*offset=*/0, /*min=*/-1.0f, /*max=*/1.0);
|
||||
|
||||
command_buffer.encode_launch_bf16_fill_random(
|
||||
bf16_fill_random_fn_,
|
||||
/*threadgroup_size=*/0,
|
||||
/*max_threadgroups=*/kFillRandomMaxThreadgroups,
|
||||
/*output_buffer=*/bias_buffer,
|
||||
/*output_offset=*/0,
|
||||
num_rows(), kSeed + 2, /*offset=*/0, /*min=*/-1.0f, /*max=*/1.0);
|
||||
|
||||
Check(gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
|
||||
command_buffer.handle(),
|
||||
f32_bf16w_matmul_fn_.handle(),
|
||||
/*threadgroup_size=*/threadgroup_size(),
|
||||
input_buffer.handle(),
|
||||
/*input_offset=*/0,
|
||||
weight_buffer.handle(),
|
||||
/*weight_offset=*/0,
|
||||
bias_buffer.handle(),
|
||||
/*bias_offset=*/0,
|
||||
output_buffer.handle(),
|
||||
/*output_offset=*/0,
|
||||
num_tokens(),
|
||||
num_cols(),
|
||||
num_rows()),
|
||||
"gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul");
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_completion();
|
||||
|
||||
const float* input_ptr = static_cast<const float*>(input_buffer.ptr());
|
||||
const gptoss_bfloat16* weight_ptr = static_cast<const gptoss_bfloat16*>(weight_buffer.ptr());
|
||||
const gptoss_bfloat16* bias_ptr = static_cast<const gptoss_bfloat16*>(bias_buffer.ptr());
|
||||
const float* output_ptr = static_cast<const float*>(output_buffer.ptr());
|
||||
for (size_t t = 0; t < num_tokens(); t++) {
|
||||
for (size_t r = 0; r < num_rows(); r++) {
|
||||
double ref_sum = upcast<double>(bias_ptr[r]);
|
||||
for (size_t c = 0; c < num_cols(); c++) {
|
||||
const double ref_weight = upcast<double>(weight_ptr[r * num_cols() + c]);
|
||||
const double input_value = upcast<double>(input_ptr[t * num_cols() + c]);
|
||||
ref_sum = std::fma(input_value, ref_weight, ref_sum);
|
||||
}
|
||||
ASSERT_NEAR(upcast<double>(output_ptr[t * num_rows() + r]), ref_sum, std::abs(ref_sum) * 1.0e-5)
|
||||
<< "token " << t;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr std::uint64_t kSeed{UINT64_C(1019827666124465388)};
|
||||
static constexpr std::size_t kFillRandomMaxThreadgroups = 10;
|
||||
static constexpr float fp4e2m1_to_fp32[16] = {
|
||||
+0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
|
||||
-0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,
|
||||
};
|
||||
|
||||
metal::Device device_{};
|
||||
metal::CommandQueue command_queue_{device_};
|
||||
metal::Library library_{device_};
|
||||
metal::Function f32_fill_random_fn_{library_, "gptoss_f32_fill_random"};
|
||||
metal::Function bf16_fill_random_fn_{library_, "gptoss_bf16_fill_random"};
|
||||
metal::Function f32_bf16w_matmul_fn_{library_, "gptoss_f32_bf16w_matmul"};
|
||||
std::uint32_t num_tokens_{1};
|
||||
std::uint32_t num_rows_{1};
|
||||
std::uint32_t num_cols_{32};
|
||||
std::size_t threadgroup_size_{32};
|
||||
};
|
||||
|
||||
} // namespace gptoss
|
||||
135
gpt_oss/metal/test/mf4-f32-convert.cc
Normal file
135
gpt_oss/metal/test/mf4-f32-convert.cc
Normal file
@@ -0,0 +1,135 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <ios>
|
||||
|
||||
#include <internal/metal.hpp>
|
||||
#include <internal/metal-kernels.h>
|
||||
|
||||
using gptoss::Check;
|
||||
using namespace gptoss::metal;
|
||||
|
||||
constexpr size_t kThreadgroupSize = 32;
|
||||
|
||||
|
||||
static float fp4e2m1_to_fp32[16] = {
|
||||
+0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
|
||||
-0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,
|
||||
};
|
||||
|
||||
TEST(MF4_F32_CONVERT, single_threadgroup_single_iteration) {
|
||||
constexpr size_t num_blocks = kThreadgroupSize;
|
||||
constexpr size_t num_elements = num_blocks * 32;
|
||||
constexpr size_t num_bytes = num_elements / 2;
|
||||
|
||||
Device device;
|
||||
CommandQueue command_queue{device};
|
||||
CommandBuffer command_buffer{command_queue};
|
||||
Library library{device};
|
||||
Function mf4_f32_convert_fn{library, "gptoss_mf4_f32_convert"};
|
||||
Buffer block_buffer{device, num_bytes};
|
||||
Buffer scale_buffer{device, num_blocks * sizeof(uint8_t)};
|
||||
Buffer output_buffer{device, num_elements * sizeof(float)};
|
||||
|
||||
uint8_t* block_ptr = static_cast<uint8_t*>(block_buffer.ptr());
|
||||
std::memset(block_ptr, 0, num_bytes);
|
||||
for (size_t b = 0; b < num_blocks; b++) {
|
||||
for (size_t i = 0; i < 32; i++) {
|
||||
const uint8_t nibble = (i + b) & 0x0F;
|
||||
const uint8_t byte = nibble << ((i % 2) * 4);
|
||||
block_ptr[b * 16 + i / 2] |= byte;
|
||||
}
|
||||
}
|
||||
|
||||
uint8_t* scale_ptr = static_cast<uint8_t*>(scale_buffer.ptr());
|
||||
for (size_t b = 0; b < num_blocks; b++) {
|
||||
scale_ptr[b] = 127 - b;
|
||||
}
|
||||
|
||||
Check(gptoss_metal_command_buffer_encode_launch_mf4_f32_convert(
|
||||
command_buffer.handle(),
|
||||
mf4_f32_convert_fn.handle(),
|
||||
/*threadgroup_size=*/kThreadgroupSize,
|
||||
/*max_threadgroups=*/1,
|
||||
block_buffer.handle(),
|
||||
scale_buffer.handle(),
|
||||
output_buffer.handle(),
|
||||
num_elements),
|
||||
"gptoss_metal_command_buffer_encode_launch_mf4_f32_convert");
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_completion();
|
||||
|
||||
const float* output_ptr = static_cast<const float*>(output_buffer.ptr());
|
||||
for (size_t b = 0; b < num_blocks; b++) {
|
||||
for (size_t i = 0; i < 32; i++) {
|
||||
const uint8_t byte = block_ptr[b * 16 + i / 2];
|
||||
const uint8_t nibble = (byte >> ((i % 2) * 4)) & 0x0F;
|
||||
const float ref_scale = std::ldexp(1.0f, static_cast<int>(scale_ptr[b]) - 127);
|
||||
const float ref_value = fp4e2m1_to_fp32[nibble] * ref_scale;
|
||||
ASSERT_EQ(output_ptr[b * 32 + i], ref_value)
|
||||
<< "at position " << i << " / 32"
|
||||
<< ", block " << b << " / " << num_blocks
|
||||
<< ", FP4e2m1 value " << std::hex << uint32_t(nibble);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(MF4_F32_CONVERT, multiple_threadgroups_multiple_iterations) {
|
||||
constexpr size_t num_threadgroups = 2;
|
||||
constexpr size_t num_blocks = num_threadgroups * (kThreadgroupSize + 1);
|
||||
constexpr size_t num_elements = num_blocks * 32;
|
||||
constexpr size_t num_bytes = num_elements / 2;
|
||||
|
||||
Device device;
|
||||
CommandQueue command_queue{device};
|
||||
CommandBuffer command_buffer{command_queue};
|
||||
Library library{device};
|
||||
Function mf4_f32_convert_fn{library, "gptoss_mf4_f32_convert"};
|
||||
Buffer block_buffer{device, num_bytes};
|
||||
Buffer scale_buffer{device, num_blocks * sizeof(uint8_t)};
|
||||
Buffer output_buffer{device, num_elements * sizeof(float)};
|
||||
|
||||
uint8_t* block_ptr = static_cast<uint8_t*>(block_buffer.ptr());
|
||||
std::memset(block_ptr, 0, num_bytes);
|
||||
for (size_t b = 0; b < num_blocks; b++) {
|
||||
for (size_t i = 0; i < 32; i++) {
|
||||
const uint8_t nibble = (i + b) & 0x0F;
|
||||
const uint8_t byte = nibble << ((i % 2) * 4);
|
||||
block_ptr[b * 16 + i / 2] |= byte;
|
||||
}
|
||||
}
|
||||
|
||||
uint8_t* scale_ptr = static_cast<uint8_t*>(scale_buffer.ptr());
|
||||
for (size_t b = 0; b < num_blocks; b++) {
|
||||
scale_ptr[b] = 200 - b;
|
||||
}
|
||||
|
||||
Check(gptoss_metal_command_buffer_encode_launch_mf4_f32_convert(
|
||||
command_buffer.handle(),
|
||||
mf4_f32_convert_fn.handle(),
|
||||
/*threadgroup_size=*/kThreadgroupSize,
|
||||
/*max_threadgroups=*/num_threadgroups,
|
||||
block_buffer.handle(),
|
||||
scale_buffer.handle(),
|
||||
output_buffer.handle(),
|
||||
num_elements),
|
||||
"gptoss_metal_command_buffer_encode_launch_mf4_f32_convert");
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_completion();
|
||||
|
||||
const float* output_ptr = static_cast<const float*>(output_buffer.ptr());
|
||||
for (size_t b = 0; b < num_blocks; b++) {
|
||||
for (size_t i = 0; i < 32; i++) {
|
||||
const uint8_t byte = block_ptr[b * 16 + i / 2];
|
||||
const uint8_t nibble = (byte >> ((i % 2) * 4)) & 0x0F;
|
||||
const float ref_scale = std::ldexp(1.0f, static_cast<int>(scale_ptr[b]) - 127);
|
||||
const float ref_value = fp4e2m1_to_fp32[nibble] * ref_scale;
|
||||
ASSERT_EQ(output_ptr[b * 32 + i], ref_value)
|
||||
<< "at position " << i << " / 32"
|
||||
<< ", block " << b << " / " << num_blocks
|
||||
<< ", FP4e2m1 value " << std::hex << uint32_t(nibble);
|
||||
}
|
||||
}
|
||||
}
|
||||
139
gpt_oss/metal/test/rmsnorm-kernel-tester.hpp
Normal file
139
gpt_oss/metal/test/rmsnorm-kernel-tester.hpp
Normal file
@@ -0,0 +1,139 @@
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
#include <internal/datatype.hpp>
|
||||
#include <internal/metal.hpp>
|
||||
#include <internal/metal-kernels.h>
|
||||
|
||||
|
||||
namespace gptoss {
|
||||
|
||||
class RMSNormKernelTester {
|
||||
public:
|
||||
RMSNormKernelTester() { }
|
||||
|
||||
RMSNormKernelTester(const RMSNormKernelTester&) = delete;
|
||||
RMSNormKernelTester(RMSNormKernelTester&&) = delete;
|
||||
RMSNormKernelTester& operator=(const RMSNormKernelTester&) = delete;
|
||||
RMSNormKernelTester& operator=(RMSNormKernelTester&&) = delete;
|
||||
|
||||
[[nodiscard]]
|
||||
RMSNormKernelTester& num_channels(std::uint32_t num_channels) {
|
||||
num_channels_ = num_channels;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::uint32_t num_channels() const {
|
||||
return num_channels_;
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
RMSNormKernelTester& num_tokens(std::uint32_t num_tokens) {
|
||||
num_tokens_ = num_tokens;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::uint32_t num_tokens() const {
|
||||
return num_tokens_;
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
RMSNormKernelTester& epsilon(float epsilon) {
|
||||
epsilon_ = epsilon;
|
||||
return *this;
|
||||
}
|
||||
|
||||
float epsilon() const {
|
||||
return epsilon_;
|
||||
}
|
||||
|
||||
void Validate() const {
|
||||
ASSERT_NE(num_channels(), 0);
|
||||
ASSERT_NE(num_tokens(), 0);
|
||||
ASSERT_GE(epsilon(), 0.0f);
|
||||
}
|
||||
|
||||
void TestF32_BF16W() const {
|
||||
Validate();
|
||||
|
||||
metal::Buffer input_buffer{device_, num_tokens() * num_channels() * sizeof(float)};
|
||||
metal::Buffer weight_buffer{device_, num_channels() * sizeof(gptoss_bfloat16)};
|
||||
metal::Buffer output_buffer{device_, num_tokens() * num_channels() * sizeof(float)};
|
||||
|
||||
metal::CommandBuffer command_buffer{command_queue_};
|
||||
|
||||
command_buffer.encode_launch_f32_fill_random(
|
||||
f32_fill_random_fn_,
|
||||
/*threadgroup_size=*/0,
|
||||
/*max_threadgroups=*/kFillRandomMaxThreadgroups,
|
||||
/*output_buffer=*/input_buffer, /*output_offset=*/0,
|
||||
num_channels(), kSeed, /*offset=*/0, /*min=*/-1.0f, /*max=*/1.0);
|
||||
|
||||
command_buffer.encode_launch_bf16_fill_random(
|
||||
bf16_fill_random_fn_,
|
||||
/*threadgroup_size=*/0,
|
||||
/*max_threadgroups=*/kFillRandomMaxThreadgroups,
|
||||
/*output_buffer=*/weight_buffer, /*output_offset=*/0,
|
||||
num_channels(), kSeed + 1, /*offset=*/0, /*min=*/-1.0f, /*max=*/1.0);
|
||||
|
||||
Check(gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
|
||||
command_buffer.handle(),
|
||||
f32_bf16w_rmsnorm_fn_.handle(),
|
||||
input_buffer.handle(),
|
||||
/*input_offset=*/0,
|
||||
weight_buffer.handle(),
|
||||
/*weight_offset=*/0,
|
||||
output_buffer.handle(),
|
||||
/*output_offset=*/0,
|
||||
num_tokens(),
|
||||
num_channels(),
|
||||
epsilon()),
|
||||
"gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm");
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_completion();
|
||||
|
||||
const float* input_ptr = static_cast<const float*>(input_buffer.ptr());
|
||||
const gptoss_bfloat16* weight_ptr = static_cast<const gptoss_bfloat16*>(weight_buffer.ptr());
|
||||
const float* output_ptr = static_cast<const float*>(output_buffer.ptr());
|
||||
for (std::uint32_t t = 0; t < num_tokens(); t++) {
|
||||
double sumsq = 0.0;
|
||||
for (std::uint32_t c = 0; c < num_channels(); c++) {
|
||||
const double val = static_cast<double>(input_ptr[t * num_channels() + c]);
|
||||
sumsq = std::fma(val, val, sumsq);
|
||||
}
|
||||
const double avgsq = sumsq / static_cast<double>(num_channels());
|
||||
const double scale = 1.0 / std::sqrt(avgsq + epsilon());
|
||||
for (std::uint32_t c = 0; c < num_channels(); c++) {
|
||||
const double input_val = upcast<double>(input_ptr[t * num_channels() + c]);
|
||||
const double weight_val = upcast<double>(weight_ptr[c]);
|
||||
const double ref_output = scale * input_val * weight_val;
|
||||
const double output = upcast<double>(output_ptr[t * num_channels() + c]);
|
||||
ASSERT_NEAR(output, ref_output, 1.0e-5 * std::abs(ref_output))
|
||||
<< "at channel " << c << " / " << num_channels() << ", token " << t << " / " << num_tokens()
|
||||
<< ", input " << input_val << ", weight " << weight_val << ", scale " << scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr std::uint64_t kSeed{UINT64_C(1019827666124465388)};
|
||||
static constexpr std::size_t kFillRandomMaxThreadgroups = 10;
|
||||
|
||||
metal::Device device_{};
|
||||
metal::CommandQueue command_queue_{device_};
|
||||
metal::Library library_{device_};
|
||||
metal::Function f32_fill_random_fn_{library_, "gptoss_f32_fill_random"};
|
||||
metal::Function bf16_fill_random_fn_{library_, "gptoss_bf16_fill_random"};
|
||||
metal::Function f32_bf16w_rmsnorm_fn_{library_, "gptoss_f32_bf16w_rmsnorm"};
|
||||
std::uint32_t num_tokens_{1};
|
||||
std::uint32_t num_channels_{1};
|
||||
float epsilon_{1.0e-5f};
|
||||
};
|
||||
|
||||
} // namespace gptoss
|
||||
204
gpt_oss/metal/test/rope-kernel-tester.hpp
Normal file
204
gpt_oss/metal/test/rope-kernel-tester.hpp
Normal file
@@ -0,0 +1,204 @@
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
#include <internal/datatype.hpp>
|
||||
#include <internal/metal.hpp>
|
||||
#include <internal/metal-kernels.h>
|
||||
|
||||
|
||||
namespace gptoss {
|
||||
|
||||
class RoPEKernelTester {
|
||||
public:
|
||||
RoPEKernelTester() { }
|
||||
|
||||
RoPEKernelTester(const RoPEKernelTester&) = delete;
|
||||
RoPEKernelTester(RoPEKernelTester&&) = delete;
|
||||
RoPEKernelTester& operator=(const RoPEKernelTester&) = delete;
|
||||
RoPEKernelTester& operator=(RoPEKernelTester&&) = delete;
|
||||
|
||||
[[nodiscard]]
|
||||
RoPEKernelTester& threadgroup_size(std::size_t threadgroup_size) {
|
||||
threadgroup_size_ = threadgroup_size;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::size_t threadgroup_size() const {
|
||||
return threadgroup_size_;
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
RoPEKernelTester& head_dim(std::uint32_t head_dim) {
|
||||
head_dim_ = head_dim;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::uint32_t head_dim() const {
|
||||
return head_dim_;
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
RoPEKernelTester& num_q_heads(std::uint32_t num_q_heads) {
|
||||
num_q_heads_ = num_q_heads;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::uint32_t num_q_heads() const {
|
||||
return num_q_heads_;
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
RoPEKernelTester& num_kv_heads(std::uint32_t num_kv_heads) {
|
||||
num_kv_heads_ = num_kv_heads;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::uint32_t num_kv_heads() const {
|
||||
return num_kv_heads_;
|
||||
}
|
||||
|
||||
std::uint32_t num_qk_heads() const {
|
||||
return num_q_heads() + num_kv_heads();
|
||||
}
|
||||
|
||||
std::uint32_t num_qkv_heads() const {
|
||||
return num_q_heads() + 2 * num_kv_heads();
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
RoPEKernelTester& num_tokens(std::uint32_t num_tokens) {
|
||||
num_tokens_ = num_tokens;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::uint32_t num_tokens() const {
|
||||
return num_tokens_;
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
RoPEKernelTester& token_offset(std::uint32_t token_offset) {
|
||||
token_offset_ = token_offset;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::uint32_t token_offset() const {
|
||||
return token_offset_;
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
RoPEKernelTester& frequency_base(float frequency_base) {
|
||||
frequency_base_ = frequency_base;
|
||||
return *this;
|
||||
}
|
||||
|
||||
float frequency_base() const {
|
||||
return frequency_base_;
|
||||
}
|
||||
|
||||
void Validate() const {
|
||||
ASSERT_NE(head_dim(), 0);
|
||||
ASSERT_EQ(head_dim() % 2, 0);
|
||||
ASSERT_NE(num_q_heads(), 0);
|
||||
ASSERT_NE(num_tokens(), 0);
|
||||
}
|
||||
|
||||
void TestF32() const {
|
||||
Validate();
|
||||
|
||||
metal::Buffer activations_buffer{device_, (num_tokens() * num_qkv_heads() + num_qk_heads()) * head_dim() * sizeof(float)};
|
||||
metal::Buffer ref_activations_buffer{device_, (num_tokens() * num_qkv_heads() + num_qk_heads()) * head_dim() * sizeof(float)};
|
||||
|
||||
metal::CommandBuffer command_buffer{command_queue_};
|
||||
|
||||
command_buffer.encode_launch_f32_fill_random(
|
||||
f32_fill_random_fn_,
|
||||
/*threadgroup_size=*/0,
|
||||
/*max_threadgroups=*/kFillRandomMaxThreadgroups,
|
||||
/*output_buffer=*/activations_buffer,
|
||||
/*output_offset=*/0,
|
||||
(num_tokens() * num_qkv_heads() + num_qk_heads()) * head_dim(),
|
||||
kSeed, /*offset=*/0, /*min=*/-1.0f, /*max=*/1.0);
|
||||
|
||||
command_buffer.encode_launch_f32_fill_random(
|
||||
f32_fill_random_fn_,
|
||||
/*threadgroup_size=*/0,
|
||||
/*max_threadgroups=*/kFillRandomMaxThreadgroups,
|
||||
/*output_buffer=*/ref_activations_buffer,
|
||||
/*output_offset=*/0,
|
||||
(num_tokens() * num_qkv_heads() + num_qk_heads()) * head_dim(),
|
||||
kSeed, /*offset=*/0, /*min=*/-1.0f, /*max=*/1.0);
|
||||
|
||||
Check(gptoss_metal_command_buffer_encode_launch_f32_rope(
|
||||
command_buffer.handle(),
|
||||
f32_rope_fn_.handle(),
|
||||
threadgroup_size(),
|
||||
activations_buffer.handle(),
|
||||
frequency_base(),
|
||||
/*interpolation_scale=*/1.0f,
|
||||
/*yarn_offset=*/0.0f,
|
||||
/*yarn_scale=*/1.0f,
|
||||
/*yarn_multiplier=*/1.0f,
|
||||
/*num_tokens=*/num_tokens(),
|
||||
/*num_q_heads=*/num_q_heads(),
|
||||
/*num_kv_heads=*/num_kv_heads(),
|
||||
head_dim(),
|
||||
/*token_offset=*/token_offset()),
|
||||
"gptoss_metal_command_buffer_encode_launch_f32_rope");
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_completion();
|
||||
|
||||
const float* ref_activations_ptr = static_cast<const float*>(ref_activations_buffer.ptr());
|
||||
const float* activations_ptr = static_cast<const float*>(activations_buffer.ptr());
|
||||
for (std::uint32_t t = 0; t < num_tokens(); t++) {
|
||||
for (std::uint32_t h = 0; h < num_qk_heads(); h++) {
|
||||
for (std::uint32_t d = 0; d < head_dim(); d += 2) {
|
||||
const double inv_freq = 1.0 /
|
||||
std::pow(static_cast<double>(frequency_base()), static_cast<double>(d) / static_cast<double>(head_dim()));
|
||||
const double phi = static_cast<double>(t + token_offset()) * inv_freq;
|
||||
const double cos_phi = std::cos(phi);
|
||||
const double sin_phi = std::sin(phi);
|
||||
const double real = static_cast<double>(ref_activations_ptr[(t * num_qkv_heads() + h) * head_dim() + d]);
|
||||
const double imag = static_cast<double>(ref_activations_ptr[(t * num_qkv_heads() + h) * head_dim() + d + 1]);
|
||||
const double ref_real = real * cos_phi - imag * sin_phi;
|
||||
const double ref_imag = real * sin_phi + imag * cos_phi;
|
||||
ASSERT_NEAR(
|
||||
static_cast<double>(activations_ptr[(t * num_qkv_heads() + h) * head_dim() + d]),
|
||||
ref_real,
|
||||
std::abs(ref_real) * 1.0e-4)
|
||||
<< "at token " << t << " / " << num_tokens();
|
||||
ASSERT_NEAR(
|
||||
static_cast<double>(activations_ptr[(t * num_qkv_heads() + h) * head_dim() + d + 1]),
|
||||
ref_imag,
|
||||
std::abs(ref_imag) * 1.0e-4)
|
||||
<< "at token " << t << " / " << num_tokens();
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr uint64_t kSeed{UINT64_C(1019827666124465388)};
|
||||
static constexpr std::size_t kFillRandomMaxThreadgroups = 10;
|
||||
|
||||
metal::Device device_{};
|
||||
metal::CommandQueue command_queue_{device_};
|
||||
metal::Library library_{device_};
|
||||
metal::Function f32_fill_random_fn_{library_, "gptoss_f32_fill_random"};
|
||||
metal::Function f32_rope_fn_{library_, "gptoss_f32_rope"};
|
||||
std::size_t threadgroup_size_{32};
|
||||
std::uint32_t head_dim_{64};
|
||||
std::uint32_t num_q_heads_{1};
|
||||
std::uint32_t num_kv_heads_{0};
|
||||
std::uint32_t num_tokens_{1};
|
||||
std::uint32_t token_offset_{0};
|
||||
float frequency_base_{50000.0f};
|
||||
};
|
||||
|
||||
} // namespace gptoss
|
||||
70
gpt_oss/metal/test/u32-random.cc
Normal file
70
gpt_oss/metal/test/u32-random.cc
Normal file
@@ -0,0 +1,70 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
#include "fill-random-kernel-tester.hpp"
|
||||
|
||||
|
||||
using gptoss::FillRandomKernelTester;
|
||||
|
||||
constexpr std::size_t kThreadgroupSize = 128;
|
||||
|
||||
TEST(U32_FILL_RANDOM, single_threadgroup_single_iteration) {
|
||||
FillRandomKernelTester()
|
||||
.num_elements(kThreadgroupSize)
|
||||
.threadgroup_size(kThreadgroupSize)
|
||||
.max_threadgroups(1)
|
||||
.TestU32();
|
||||
}
|
||||
|
||||
TEST(U32_FILL_RANDOM, single_threadgroup_multiple_iterations) {
|
||||
constexpr std::size_t num_iterations = 3;
|
||||
|
||||
FillRandomKernelTester()
|
||||
.num_elements(num_iterations * kThreadgroupSize)
|
||||
.threadgroup_size(kThreadgroupSize)
|
||||
.max_threadgroups(1)
|
||||
.TestU32();
|
||||
}
|
||||
|
||||
TEST(U32_FILL_RANDOM, multiple_threadgroups_multiple_iterations) {
|
||||
constexpr std::size_t num_iterations = 3;
|
||||
constexpr std::size_t num_threadgroups = 2;
|
||||
|
||||
FillRandomKernelTester()
|
||||
.num_elements(num_iterations * num_threadgroups * kThreadgroupSize)
|
||||
.threadgroup_size(kThreadgroupSize)
|
||||
.max_threadgroups(num_threadgroups)
|
||||
.TestU32();
|
||||
}
|
||||
|
||||
TEST(U32_FILL_RANDOM, excessive_threadgroups) {
|
||||
FillRandomKernelTester()
|
||||
.num_elements(kThreadgroupSize)
|
||||
.threadgroup_size(kThreadgroupSize)
|
||||
.max_threadgroups(2)
|
||||
.TestU32();
|
||||
}
|
||||
|
||||
TEST(U32_FILL_RANDOM, nonuniform_range) {
|
||||
constexpr std::size_t num_iterations = 3;
|
||||
constexpr std::size_t num_threadgroups = 2;
|
||||
|
||||
FillRandomKernelTester()
|
||||
.num_elements((num_iterations * num_threadgroups + 1) * kThreadgroupSize)
|
||||
.threadgroup_size(kThreadgroupSize)
|
||||
.max_threadgroups(num_threadgroups)
|
||||
.TestU32();
|
||||
}
|
||||
|
||||
TEST(U32_FILL_RANDOM, partial_range) {
|
||||
constexpr std::size_t num_iterations = 3;
|
||||
constexpr std::size_t num_threadgroups = 2;
|
||||
|
||||
FillRandomKernelTester()
|
||||
.num_elements(num_iterations * num_threadgroups * kThreadgroupSize + 1)
|
||||
.threadgroup_size(kThreadgroupSize)
|
||||
.max_threadgroups(num_threadgroups)
|
||||
.TestU32();
|
||||
}
|
||||
0
gpt_oss/responses_api/__init__.py
Normal file
0
gpt_oss/responses_api/__init__.py
Normal file
915
gpt_oss/responses_api/api_server.py
Normal file
915
gpt_oss/responses_api/api_server.py
Normal file
@@ -0,0 +1,915 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import uuid
|
||||
from typing import Callable, Literal, Optional
|
||||
import json
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from openai_harmony import (
|
||||
Author,
|
||||
Conversation,
|
||||
DeveloperContent,
|
||||
HarmonyEncoding,
|
||||
Message,
|
||||
ReasoningEffort,
|
||||
Role,
|
||||
StreamableParser,
|
||||
StreamState,
|
||||
SystemContent,
|
||||
ToolDescription,
|
||||
)
|
||||
|
||||
from gpt_oss.tools.simple_browser import SimpleBrowserTool
|
||||
from gpt_oss.tools.simple_browser.backend import ExaBackend
|
||||
|
||||
from .events import (
|
||||
ResponseCompletedEvent,
|
||||
ResponseCreatedEvent,
|
||||
ResponseInProgressEvent,
|
||||
ResponseEvent,
|
||||
ResponseOutputItemAdded,
|
||||
ResponseOutputItemDone,
|
||||
ResponseContentPartAdded,
|
||||
ResponseContentPartDone,
|
||||
ResponseOutputTextDone,
|
||||
ResponseOutputTextDelta,
|
||||
ResponseReasoningTextDone,
|
||||
ResponseReasoningTextDelta,
|
||||
ResponseWebSearchCallInProgress,
|
||||
ResponseWebSearchCallSearching,
|
||||
ResponseWebSearchCallCompleted,
|
||||
ResponseOutputTextAnnotationAdded
|
||||
)
|
||||
from .types import (
|
||||
UrlCitation,
|
||||
Error,
|
||||
FunctionCallItem,
|
||||
Item,
|
||||
ReasoningItem,
|
||||
ReasoningTextContentItem,
|
||||
ResponseObject,
|
||||
ResponsesRequest,
|
||||
TextContentItem,
|
||||
Usage,
|
||||
WebSearchCallItem,
|
||||
WebSearchActionSearch,
|
||||
WebSearchActionOpenPage,
|
||||
WebSearchActionFind,
|
||||
)
|
||||
|
||||
DEFAULT_TEMPERATURE = 0.0
|
||||
|
||||
|
||||
def get_reasoning_effort(effort: Literal["low", "medium", "high"]) -> ReasoningEffort:
|
||||
if effort == "low":
|
||||
return ReasoningEffort.LOW
|
||||
elif effort == "medium":
|
||||
return ReasoningEffort.MEDIUM
|
||||
elif effort == "high":
|
||||
return ReasoningEffort.HIGH
|
||||
|
||||
|
||||
def is_not_builtin_tool(recipient: str) -> bool:
|
||||
return not recipient.startswith("browser.") and not recipient == "python" and not recipient == "assistant"
|
||||
|
||||
def create_api_server(
|
||||
infer_next_token: Callable[[list[int], float], int], encoding: HarmonyEncoding
|
||||
) -> FastAPI:
|
||||
app = FastAPI()
|
||||
responses_store: dict[str, tuple[ResponsesRequest, ResponseObject]] = {}
|
||||
|
||||
def generate_response(
|
||||
input_tokens: list[int],
|
||||
output_tokens: list[int],
|
||||
request_body: ResponsesRequest,
|
||||
debug_mode: bool = False,
|
||||
function_call_ids: Optional[list[tuple[str, str]]] = None,
|
||||
response_id: Optional[str] = None,
|
||||
previous_response_id: Optional[str] = None,
|
||||
browser_tool: Optional[SimpleBrowserTool] = None,
|
||||
browser_call_ids: Optional[list[str]] = None,
|
||||
) -> ResponseObject:
|
||||
output = []
|
||||
error = None
|
||||
if len(output_tokens) > 0:
|
||||
if debug_mode:
|
||||
try:
|
||||
entries = encoding.parse_messages_from_completion_tokens(
|
||||
output_tokens, Role.ASSISTANT
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error parsing tokens: {e}")
|
||||
error = Error(
|
||||
code="invalid_function_call",
|
||||
message=f"{e}",
|
||||
)
|
||||
entries = []
|
||||
else:
|
||||
entries = encoding.parse_messages_from_completion_tokens(
|
||||
output_tokens, Role.ASSISTANT
|
||||
)
|
||||
|
||||
fc_index = 0
|
||||
browser_tool_index = 0
|
||||
for entry in entries:
|
||||
entry_dict = entry.to_dict()
|
||||
if len(entry_dict.get("recipient", "")) > 0 and is_not_builtin_tool(entry_dict["recipient"]):
|
||||
call = entry_dict["content"][0]
|
||||
arguments = call["text"]
|
||||
name = entry_dict["recipient"]
|
||||
|
||||
if name.startswith("functions."):
|
||||
name = name[len("functions.") :]
|
||||
if function_call_ids and fc_index < len(function_call_ids):
|
||||
fc_id, call_id = function_call_ids[fc_index]
|
||||
else:
|
||||
fc_id, call_id = (
|
||||
f"fc_{uuid.uuid4().hex}",
|
||||
f"call_{uuid.uuid4().hex}",
|
||||
)
|
||||
fc_index += 1
|
||||
output.append(
|
||||
FunctionCallItem(
|
||||
type="function_call",
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
id=fc_id,
|
||||
call_id=call_id,
|
||||
)
|
||||
)
|
||||
elif len(entry_dict.get("recipient", "")) > 0 and entry_dict["recipient"].startswith("browser.") and browser_tool is not None:
|
||||
# Mirror event-based creation of WebSearchCallItems when the browser tool is invoked
|
||||
name = entry_dict["recipient"]
|
||||
call = entry_dict["content"][0]
|
||||
arguments = call["text"]
|
||||
function_name = name[len("browser."):]
|
||||
|
||||
# Reconstruct a Message for argument parsing
|
||||
tool_msg = (
|
||||
Message.from_role_and_content(Role.ASSISTANT, arguments)
|
||||
.with_recipient(name)
|
||||
.with_channel("analysis")
|
||||
)
|
||||
|
||||
action = None
|
||||
try:
|
||||
parsed_args = browser_tool.process_arguments(tool_msg)
|
||||
if function_name == "search":
|
||||
action = WebSearchActionSearch(
|
||||
type="search",
|
||||
query=parsed_args["query"],
|
||||
)
|
||||
elif function_name == "open":
|
||||
action = WebSearchActionOpenPage(
|
||||
type="open_page",
|
||||
url=parsed_args["url"],
|
||||
)
|
||||
elif function_name == "find":
|
||||
action = WebSearchActionFind(
|
||||
type="find",
|
||||
pattern=parsed_args["pattern"],
|
||||
url=parsed_args["url"],
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error processing browser tool arguments: {e}")
|
||||
action = None
|
||||
|
||||
if action is not None:
|
||||
if browser_call_ids and browser_tool_index < len(browser_call_ids):
|
||||
web_search_call_id = browser_call_ids[browser_tool_index]
|
||||
else:
|
||||
web_search_call_id = f"ws_{uuid.uuid4().hex}"
|
||||
browser_tool_index += 1
|
||||
output.append(
|
||||
WebSearchCallItem(
|
||||
type="web_search_call",
|
||||
id=web_search_call_id,
|
||||
action=action,
|
||||
)
|
||||
)
|
||||
elif entry_dict["channel"] == "final":
|
||||
content = []
|
||||
for content_entry in entry_dict["content"]:
|
||||
if browser_tool:
|
||||
text_content, annotation_entries, _has_partial_citations = browser_tool.normalize_citations(content_entry["text"])
|
||||
annotations = [UrlCitation(**a) for a in annotation_entries]
|
||||
else:
|
||||
text_content = content_entry["text"]
|
||||
annotations = []
|
||||
|
||||
content.append(
|
||||
TextContentItem(
|
||||
type="output_text",
|
||||
text=text_content,
|
||||
annotations=annotations,
|
||||
)
|
||||
)
|
||||
|
||||
output.append(
|
||||
Item(
|
||||
type="message",
|
||||
role="assistant",
|
||||
content=content,
|
||||
status="completed",
|
||||
)
|
||||
)
|
||||
elif entry_dict["channel"] == "analysis":
|
||||
summary = []
|
||||
content = [
|
||||
ReasoningTextContentItem(
|
||||
type="reasoning_text",
|
||||
text=entry["text"],
|
||||
)
|
||||
for entry in entry_dict["content"]
|
||||
]
|
||||
output.append(
|
||||
ReasoningItem(
|
||||
type="reasoning",
|
||||
summary=summary,
|
||||
content=content,
|
||||
)
|
||||
)
|
||||
else:
|
||||
output = []
|
||||
|
||||
usage = (
|
||||
Usage(
|
||||
input_tokens=len(input_tokens),
|
||||
output_tokens=len(output_tokens),
|
||||
total_tokens=len(input_tokens) + len(output_tokens),
|
||||
)
|
||||
if len(output_tokens) > 0
|
||||
else None
|
||||
)
|
||||
|
||||
try:
|
||||
debug_str = encoding.decode_utf8(input_tokens + output_tokens)
|
||||
except Exception:
|
||||
debug_str = input_tokens + output_tokens
|
||||
try:
|
||||
debug_input_str = encoding.decode_utf8(input_tokens)
|
||||
except Exception:
|
||||
debug_input_str = input_tokens
|
||||
try:
|
||||
debug_output_str = encoding.decode_utf8(output_tokens)
|
||||
except Exception:
|
||||
debug_output_str = output_tokens
|
||||
|
||||
metadata = (
|
||||
{
|
||||
"__debug": debug_str,
|
||||
"__debug_input": debug_input_str,
|
||||
"__debug_output": debug_output_str,
|
||||
}
|
||||
if debug_mode
|
||||
else {}
|
||||
)
|
||||
|
||||
return ResponseObject(
|
||||
created_at=int(datetime.datetime.now().timestamp()),
|
||||
status="completed",
|
||||
output=output,
|
||||
text={"format": {"type": "text"}},
|
||||
usage=usage,
|
||||
max_output_tokens=request_body.max_output_tokens,
|
||||
error=error,
|
||||
metadata=metadata,
|
||||
id=response_id,
|
||||
previous_response_id=previous_response_id,
|
||||
)
|
||||
|
||||
class StreamResponsesEvents:
|
||||
initial_tokens: list[int]
|
||||
tokens: list[int]
|
||||
output_tokens: list[int]
|
||||
output_text: str
|
||||
request_body: ResponsesRequest
|
||||
request: Request
|
||||
sequence_number: int
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initial_tokens,
|
||||
request_body: ResponsesRequest,
|
||||
as_sse: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
response_id: Optional[str] = None,
|
||||
store_callback: Optional[
|
||||
Callable[[str, ResponsesRequest, ResponseObject], None]
|
||||
] = None,
|
||||
browser_tool: Optional[SimpleBrowserTool] = None,
|
||||
):
|
||||
self.initial_tokens = initial_tokens
|
||||
self.tokens = initial_tokens.copy()
|
||||
self.output_tokens = []
|
||||
self.output_text = ""
|
||||
self.request_body = request_body
|
||||
self.parser = StreamableParser(encoding, role=Role.ASSISTANT)
|
||||
self.as_sse = as_sse
|
||||
self.debug_mode = request_body.metadata.get(
|
||||
"__debug", False
|
||||
) # we use this for demo purposes
|
||||
# Set temperature for this stream, fallback to DEFAULT_TEMPERATURE if not set
|
||||
self.temperature = (
|
||||
request_body.temperature
|
||||
if request_body.temperature is not None
|
||||
else DEFAULT_TEMPERATURE
|
||||
)
|
||||
self.request = request
|
||||
self.sequence_number = 0
|
||||
self.function_call_ids: list[tuple[str, str]] = []
|
||||
self.response_id = response_id
|
||||
self.store_callback = store_callback
|
||||
self.new_request = True
|
||||
self.browser_tool = browser_tool
|
||||
self.use_browser_tool = browser_tool is not None
|
||||
self.browser_call_ids: list[str] = []
|
||||
|
||||
def _send_event(self, event: ResponseEvent):
|
||||
event.sequence_number = self.sequence_number
|
||||
self.sequence_number += 1
|
||||
if self.as_sse:
|
||||
return f"event: {event.type}\ndata: {event.model_dump_json(indent=None)}\n\n"
|
||||
else:
|
||||
return event
|
||||
|
||||
async def run(self):
|
||||
browser_tool = self.browser_tool
|
||||
self.new_request = True
|
||||
initial_response = generate_response(
|
||||
self.initial_tokens,
|
||||
self.output_tokens,
|
||||
self.request_body,
|
||||
function_call_ids=self.function_call_ids,
|
||||
response_id=self.response_id,
|
||||
previous_response_id=self.request_body.previous_response_id,
|
||||
)
|
||||
initial_response.status = "in_progress"
|
||||
yield self._send_event(
|
||||
ResponseCreatedEvent(
|
||||
type="response.created",
|
||||
response=initial_response,
|
||||
)
|
||||
)
|
||||
yield self._send_event(
|
||||
ResponseInProgressEvent(
|
||||
type="response.in_progress",
|
||||
response=initial_response,
|
||||
)
|
||||
)
|
||||
|
||||
current_content_index = (
|
||||
0 # for this implementation we will always have one content item only
|
||||
)
|
||||
current_output_index = -1
|
||||
sent_output_item_added = False
|
||||
|
||||
# we use this if the model outputs a citation to buffer until completed
|
||||
output_delta_buffer = ""
|
||||
# we use this to track the current output text content for things like providing the right indices in citations
|
||||
current_output_text_content = ""
|
||||
current_annotations = []
|
||||
|
||||
while True:
|
||||
# Check for client disconnect
|
||||
if self.request is not None and await self.request.is_disconnected():
|
||||
print("Client disconnected, stopping token generation.")
|
||||
break
|
||||
next_tok = infer_next_token(
|
||||
self.tokens,
|
||||
temperature=self.temperature,
|
||||
new_request=self.new_request,
|
||||
)
|
||||
self.new_request = False
|
||||
self.tokens.append(next_tok)
|
||||
try:
|
||||
self.parser.process(next_tok)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
if self.parser.state == StreamState.EXPECT_START:
|
||||
current_output_index += 1
|
||||
sent_output_item_added = False
|
||||
|
||||
if len(self.parser.messages) > 0:
|
||||
previous_item = self.parser.messages[-1]
|
||||
if previous_item.recipient is not None:
|
||||
recipient = previous_item.recipient
|
||||
if (
|
||||
not recipient.startswith("browser.")
|
||||
and not recipient == "python"
|
||||
):
|
||||
fc_id = f"fc_{uuid.uuid4().hex}"
|
||||
call_id = f"call_{uuid.uuid4().hex}"
|
||||
self.function_call_ids.append((fc_id, call_id))
|
||||
yield self._send_event(
|
||||
ResponseOutputItemDone(
|
||||
type="response.output_item.done",
|
||||
output_index=current_output_index,
|
||||
item=FunctionCallItem(
|
||||
type="function_call",
|
||||
name=(
|
||||
previous_item.recipient[
|
||||
len("functions.") :
|
||||
]
|
||||
if previous_item.recipient.startswith(
|
||||
"functions."
|
||||
)
|
||||
else previous_item.recipient
|
||||
),
|
||||
arguments=previous_item.content[0].text,
|
||||
id=fc_id,
|
||||
call_id=call_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
if previous_item.channel == "analysis":
|
||||
yield self._send_event(
|
||||
ResponseReasoningTextDone(
|
||||
type="response.reasoning_text.done",
|
||||
output_index=current_output_index,
|
||||
content_index=current_content_index,
|
||||
text=previous_item.content[0].text,
|
||||
)
|
||||
)
|
||||
yield self._send_event(
|
||||
ResponseContentPartDone(
|
||||
type="response.content_part.done",
|
||||
output_index=current_output_index,
|
||||
content_index=current_content_index,
|
||||
part=ReasoningTextContentItem(
|
||||
type="reasoning_text",
|
||||
text=previous_item.content[0].text,
|
||||
),
|
||||
)
|
||||
)
|
||||
yield self._send_event(
|
||||
ResponseOutputItemDone(
|
||||
type="response.output_item.done",
|
||||
output_index=current_output_index,
|
||||
item=ReasoningItem(
|
||||
type="reasoning",
|
||||
summary=[],
|
||||
content=[
|
||||
ReasoningTextContentItem(
|
||||
type="reasoning_text",
|
||||
text=previous_item.content[0].text,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
)
|
||||
if previous_item.channel == "final":
|
||||
annotations = [UrlCitation(**a) for a in current_annotations]
|
||||
if browser_tool:
|
||||
normalized_text, _annotations, _has_partial_citations = browser_tool.normalize_citations(previous_item.content[0].text)
|
||||
else:
|
||||
normalized_text = previous_item.content[0].text
|
||||
annotations = []
|
||||
text_content = TextContentItem(
|
||||
type="output_text",
|
||||
text=normalized_text,
|
||||
annotations=annotations,
|
||||
)
|
||||
yield self._send_event(
|
||||
ResponseOutputTextDone(
|
||||
type="response.output_text.done",
|
||||
output_index=current_output_index,
|
||||
content_index=current_content_index,
|
||||
text=normalized_text,
|
||||
)
|
||||
)
|
||||
yield self._send_event(
|
||||
ResponseContentPartDone(
|
||||
type="response.content_part.done",
|
||||
output_index=current_output_index,
|
||||
content_index=current_content_index,
|
||||
part=text_content,
|
||||
)
|
||||
)
|
||||
yield self._send_event(
|
||||
ResponseOutputItemDone(
|
||||
type="response.output_item.done",
|
||||
output_index=current_output_index,
|
||||
item=Item(
|
||||
type="message",
|
||||
role="assistant",
|
||||
content=[text_content],
|
||||
),
|
||||
)
|
||||
)
|
||||
current_annotations = []
|
||||
current_output_text_content = ""
|
||||
|
||||
if (
|
||||
self.parser.last_content_delta
|
||||
and self.parser.current_channel == "final"
|
||||
and self.parser.current_recipient is None
|
||||
):
|
||||
if not sent_output_item_added:
|
||||
sent_output_item_added = True
|
||||
yield self._send_event(
|
||||
ResponseOutputItemAdded(
|
||||
type="response.output_item.added",
|
||||
output_index=current_output_index,
|
||||
item=Item(type="message", role="assistant", content=[]),
|
||||
)
|
||||
)
|
||||
yield self._send_event(
|
||||
ResponseContentPartAdded(
|
||||
type="response.content_part.added",
|
||||
output_index=current_output_index,
|
||||
content_index=current_content_index,
|
||||
part=TextContentItem(type="output_text", text=""),
|
||||
)
|
||||
)
|
||||
|
||||
output_delta_buffer += self.parser.last_content_delta
|
||||
should_send_output_text_delta = True
|
||||
if browser_tool:
|
||||
# we normalize on the full current text to get the right indices in citations
|
||||
updated_output_text, annotations, has_partial_citations = browser_tool.normalize_citations(current_output_text_content + output_delta_buffer)
|
||||
# remove the current text to get back the delta but now normalized
|
||||
output_delta_buffer = updated_output_text[len(current_output_text_content):]
|
||||
|
||||
# Filter annotations to only include those whose start_index is not already present in current_annotations
|
||||
# this is to avoid sending duplicate annotations as multiple annotations can't be in the same place
|
||||
existing_start_indices = {a["start_index"] for a in current_annotations}
|
||||
new_annotations = [a for a in annotations if a["start_index"] not in existing_start_indices]
|
||||
for a in new_annotations:
|
||||
current_annotations.append(a)
|
||||
citation = UrlCitation(**a)
|
||||
yield self._send_event(
|
||||
ResponseOutputTextAnnotationAdded(
|
||||
type="response.output_text.annotation.added",
|
||||
output_index=current_output_index,
|
||||
content_index=current_content_index,
|
||||
annotation_index=len(current_annotations),
|
||||
annotation=citation,
|
||||
)
|
||||
)
|
||||
|
||||
if has_partial_citations:
|
||||
should_send_output_text_delta = False
|
||||
|
||||
|
||||
if should_send_output_text_delta:
|
||||
yield self._send_event(
|
||||
ResponseOutputTextDelta(
|
||||
type="response.output_text.delta",
|
||||
output_index=current_output_index,
|
||||
content_index=current_content_index,
|
||||
delta=output_delta_buffer,
|
||||
)
|
||||
)
|
||||
current_output_text_content += output_delta_buffer
|
||||
output_delta_buffer = ""
|
||||
|
||||
if (
|
||||
self.parser.last_content_delta
|
||||
and self.parser.current_channel == "analysis"
|
||||
and self.parser.current_recipient is None
|
||||
):
|
||||
if not sent_output_item_added:
|
||||
sent_output_item_added = True
|
||||
yield self._send_event(
|
||||
ResponseOutputItemAdded(
|
||||
type="response.output_item.added",
|
||||
output_index=current_output_index,
|
||||
item=ReasoningItem(
|
||||
type="reasoning", summary=[], content=[]
|
||||
),
|
||||
)
|
||||
)
|
||||
yield self._send_event(
|
||||
ResponseContentPartAdded(
|
||||
type="response.content_part.added",
|
||||
output_index=current_output_index,
|
||||
content_index=current_content_index,
|
||||
part=ReasoningTextContentItem(type="reasoning_text", text=""),
|
||||
)
|
||||
)
|
||||
yield self._send_event(
|
||||
ResponseReasoningTextDelta(
|
||||
type="response.reasoning_text.delta",
|
||||
output_index=current_output_index,
|
||||
content_index=current_content_index,
|
||||
delta=self.parser.last_content_delta,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# purely for debugging purposes
|
||||
output_token_text = encoding.decode_utf8([next_tok])
|
||||
self.output_text += output_token_text
|
||||
print(output_token_text, end="", flush=True)
|
||||
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
if next_tok in encoding.stop_tokens_for_assistant_actions():
|
||||
if len(self.parser.messages) > 0:
|
||||
last_message = self.parser.messages[-1]
|
||||
if (
|
||||
self.use_browser_tool
|
||||
and last_message.recipient is not None
|
||||
and last_message.recipient.startswith("browser.")
|
||||
):
|
||||
function_name = last_message.recipient[len("browser."):]
|
||||
action = None
|
||||
parsed_args = browser_tool.process_arguments(last_message)
|
||||
if function_name == "search":
|
||||
action = WebSearchActionSearch(
|
||||
type="search",
|
||||
query=parsed_args["query"],
|
||||
)
|
||||
elif function_name == "open":
|
||||
action = WebSearchActionOpenPage(
|
||||
type="open_page",
|
||||
url=parsed_args["url"] if "url" in parsed_args else None,
|
||||
)
|
||||
elif function_name == "find":
|
||||
action = WebSearchActionFind(
|
||||
type="find",
|
||||
pattern=parsed_args["pattern"],
|
||||
url=parsed_args["url"] if "url" in parsed_args else None,
|
||||
)
|
||||
|
||||
if action is not None:
|
||||
web_search_call_id = f"ws_{uuid.uuid4().hex}"
|
||||
self.browser_call_ids.append(web_search_call_id)
|
||||
yield self._send_event(ResponseOutputItemAdded(
|
||||
type="response.output_item.added",
|
||||
output_index=current_output_index,
|
||||
item=WebSearchCallItem(
|
||||
type="web_search_call",
|
||||
id=web_search_call_id,
|
||||
action=action,
|
||||
),
|
||||
))
|
||||
yield self._send_event(
|
||||
ResponseWebSearchCallInProgress(
|
||||
type="response.web_search_call.in_progress",
|
||||
output_index=current_output_index,
|
||||
id=web_search_call_id
|
||||
)
|
||||
)
|
||||
|
||||
async def run_tool():
|
||||
results = []
|
||||
async for msg in browser_tool.process(last_message):
|
||||
results.append(msg)
|
||||
return results
|
||||
|
||||
yield self._send_event(
|
||||
ResponseWebSearchCallSearching(
|
||||
type="response.web_search_call.searching",
|
||||
output_index=current_output_index,
|
||||
id=web_search_call_id,
|
||||
)
|
||||
)
|
||||
result = await run_tool()
|
||||
|
||||
new_tokens = encoding.render_conversation_for_completion(
|
||||
Conversation.from_messages(result), Role.ASSISTANT
|
||||
)
|
||||
|
||||
print(encoding.decode_utf8(new_tokens))
|
||||
self.output_tokens.append(next_tok)
|
||||
self.tokens.append(encoding.encode('<|end|>', allowed_special="all")[0])
|
||||
|
||||
for token in new_tokens:
|
||||
self.parser.process(token)
|
||||
self.output_tokens.append(token)
|
||||
self.tokens.append(token)
|
||||
|
||||
yield self._send_event(
|
||||
ResponseWebSearchCallCompleted(
|
||||
type="response.web_search_call.completed",
|
||||
output_index=current_output_index,
|
||||
id=web_search_call_id,
|
||||
)
|
||||
)
|
||||
yield self._send_event(ResponseOutputItemDone(
|
||||
type="response.output_item.done",
|
||||
output_index=current_output_index,
|
||||
item=WebSearchCallItem(
|
||||
type="web_search_call",
|
||||
id=web_search_call_id,
|
||||
action=action,
|
||||
),
|
||||
))
|
||||
|
||||
current_output_index += 1
|
||||
self.new_request = True
|
||||
|
||||
continue
|
||||
|
||||
else:
|
||||
break
|
||||
else:
|
||||
raise ValueError("No messages to process")
|
||||
if len(self.output_tokens) >= self.request_body.max_output_tokens:
|
||||
break
|
||||
|
||||
# Adding in the end if we know we are not done
|
||||
self.output_tokens.append(next_tok)
|
||||
|
||||
if self.request is None or not await self.request.is_disconnected():
|
||||
response = generate_response(
|
||||
self.initial_tokens,
|
||||
self.output_tokens,
|
||||
self.request_body,
|
||||
debug_mode=self.debug_mode,
|
||||
function_call_ids=self.function_call_ids,
|
||||
response_id=self.response_id,
|
||||
previous_response_id=self.request_body.previous_response_id,
|
||||
browser_tool=self.browser_tool,
|
||||
browser_call_ids=self.browser_call_ids,
|
||||
)
|
||||
if self.store_callback and self.request_body.store:
|
||||
self.store_callback(self.response_id, self.request_body, response)
|
||||
yield self._send_event(
|
||||
ResponseCompletedEvent(
|
||||
type="response.completed",
|
||||
response=response,
|
||||
)
|
||||
)
|
||||
|
||||
@app.post("/v1/responses", response_model=ResponseObject)
|
||||
async def generate(body: ResponsesRequest, request: Request):
|
||||
print("request received")
|
||||
|
||||
use_browser_tool = any(
|
||||
getattr(tool, "type", None) == "browser_search"
|
||||
for tool in (body.tools or [])
|
||||
)
|
||||
|
||||
if use_browser_tool:
|
||||
backend = ExaBackend(
|
||||
source="web",
|
||||
)
|
||||
browser_tool = SimpleBrowserTool(backend=backend)
|
||||
else:
|
||||
browser_tool = None
|
||||
|
||||
if body.previous_response_id:
|
||||
prev = responses_store.get(body.previous_response_id)
|
||||
if prev:
|
||||
prev_req, prev_resp = prev
|
||||
|
||||
def _ensure_list(inp):
|
||||
if isinstance(inp, str):
|
||||
return [
|
||||
Item(
|
||||
type="message",
|
||||
role="user",
|
||||
content=[TextContentItem(type="input_text", text=inp)],
|
||||
)
|
||||
]
|
||||
return list(inp)
|
||||
|
||||
merged_input = _ensure_list(prev_req.input) + list(prev_resp.output)
|
||||
merged_input.extend(_ensure_list(body.input))
|
||||
|
||||
if body.instructions is None:
|
||||
body.instructions = prev_req.instructions
|
||||
body.input = merged_input
|
||||
|
||||
|
||||
system_message_content = SystemContent.new().with_conversation_start_date(
|
||||
datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
)
|
||||
|
||||
if body.reasoning is not None:
|
||||
reasoning_effort = get_reasoning_effort(body.reasoning.effort)
|
||||
system_message_content = system_message_content.with_reasoning_effort(reasoning_effort)
|
||||
|
||||
if use_browser_tool:
|
||||
system_message_content = system_message_content.with_tools(browser_tool.tool_config)
|
||||
|
||||
system_message = Message.from_role_and_content(
|
||||
Role.SYSTEM, system_message_content
|
||||
)
|
||||
|
||||
developer_message_content = DeveloperContent.new().with_instructions(
|
||||
body.instructions
|
||||
)
|
||||
|
||||
tools = []
|
||||
if body.tools:
|
||||
for tool in body.tools:
|
||||
if tool.type == "function":
|
||||
has_functions = True
|
||||
tools.append(
|
||||
ToolDescription.new(
|
||||
tool.name,
|
||||
tool.description,
|
||||
tool.parameters,
|
||||
)
|
||||
)
|
||||
|
||||
if len(tools) > 0:
|
||||
developer_message_content = developer_message_content.with_function_tools(
|
||||
tools
|
||||
)
|
||||
|
||||
developer_message = Message.from_role_and_content(
|
||||
Role.DEVELOPER, developer_message_content
|
||||
)
|
||||
|
||||
messages = [system_message, developer_message]
|
||||
|
||||
if isinstance(body.input, str):
|
||||
user_message = Message.from_role_and_content(Role.USER, body.input)
|
||||
messages.append(user_message)
|
||||
else:
|
||||
is_last_message_function_call_output = (
|
||||
len(body.input) > 0 and body.input[-1].type == "function_call_output"
|
||||
)
|
||||
function_call_map = {}
|
||||
# Find the index of the last assistant message
|
||||
last_assistant_idx = -1
|
||||
for idx, item in enumerate(body.input):
|
||||
if item.type == "message" and item.role == Role.ASSISTANT:
|
||||
last_assistant_idx = idx
|
||||
|
||||
for idx, item in enumerate(body.input):
|
||||
if item.type == "message":
|
||||
# TODO: add system prompt handling
|
||||
if isinstance(item.content, str):
|
||||
messages.append(
|
||||
Message.from_role_and_content(item.role, item.content)
|
||||
)
|
||||
else:
|
||||
for content_item in item.content:
|
||||
messages.append(
|
||||
Message.from_role_and_content(item.role, content_item.text)
|
||||
)
|
||||
# add final channel to the last assistant message if it's from the assistant
|
||||
if item.role == Role.ASSISTANT:
|
||||
messages[-1] = messages[-1].with_channel("final")
|
||||
elif item.type == "reasoning":
|
||||
# Only include reasoning if it is after the last assistant message and we are handling a function call at the moment
|
||||
if (
|
||||
idx > last_assistant_idx
|
||||
and is_last_message_function_call_output
|
||||
):
|
||||
for content_item in item.content:
|
||||
messages.append(
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, content_item.text
|
||||
).with_channel("analysis")
|
||||
)
|
||||
elif item.type == "function_call":
|
||||
function_call_map[item.call_id] = item
|
||||
messages.append(
|
||||
Message.from_role_and_content(Role.ASSISTANT, item.arguments)
|
||||
.with_recipient(f"functions.{item.name}")
|
||||
.with_channel("commentary")
|
||||
)
|
||||
elif item.type == "function_call_output":
|
||||
function_call = function_call_map.get(item.call_id, None)
|
||||
if not function_call:
|
||||
raise ValueError(f"Function call {item.call_id} not found")
|
||||
|
||||
messages.append(
|
||||
Message.from_author_and_content(
|
||||
Author.new(Role.TOOL, f"functions.{function_call.name}"),
|
||||
item.output,
|
||||
).with_recipient("assistant").with_channel("commentary")
|
||||
)
|
||||
|
||||
conversation = Conversation.from_messages(messages)
|
||||
|
||||
initial_tokens = encoding.render_conversation_for_completion(
|
||||
conversation, Role.ASSISTANT
|
||||
)
|
||||
print(encoding.decode_utf8(initial_tokens))
|
||||
response_id = f"resp_{uuid.uuid4().hex}"
|
||||
|
||||
def store_callback(rid: str, req: ResponsesRequest, resp: ResponseObject):
|
||||
responses_store[rid] = (req, resp)
|
||||
|
||||
event_stream = StreamResponsesEvents(
|
||||
initial_tokens,
|
||||
body,
|
||||
as_sse=body.stream,
|
||||
request=request,
|
||||
response_id=response_id,
|
||||
store_callback=store_callback,
|
||||
browser_tool=browser_tool,
|
||||
)
|
||||
|
||||
if body.stream:
|
||||
return StreamingResponse(event_stream.run(), media_type="text/event-stream")
|
||||
else:
|
||||
last_event = None
|
||||
async for event in event_stream.run():
|
||||
last_event = event
|
||||
|
||||
return last_event.response
|
||||
|
||||
return app
|
||||
129
gpt_oss/responses_api/events.py
Normal file
129
gpt_oss/responses_api/events.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# torchrun --nproc-per-node=4 responses_api.py
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .types import (
|
||||
FunctionCallItem,
|
||||
Item,
|
||||
ReasoningItem,
|
||||
ResponseObject,
|
||||
TextContentItem,
|
||||
ReasoningTextContentItem,
|
||||
WebSearchCallItem,
|
||||
UrlCitation,
|
||||
)
|
||||
|
||||
|
||||
class ResponseEvent(BaseModel):
|
||||
sequence_number: Optional[int] = 1
|
||||
|
||||
|
||||
class ResponseCreatedEvent(ResponseEvent):
|
||||
type: Literal["response.created"]
|
||||
response: ResponseObject
|
||||
|
||||
|
||||
class ResponseCompletedEvent(ResponseEvent):
|
||||
type: Literal["response.completed"]
|
||||
response: ResponseObject
|
||||
|
||||
|
||||
class ResponseOutputTextDelta(ResponseEvent):
|
||||
type: Literal["response.output_text.delta"] = "response.output_text.delta"
|
||||
item_id: str = "item_1234"
|
||||
output_index: int = 0
|
||||
content_index: int = 0
|
||||
delta: str = ""
|
||||
logprobs: list = []
|
||||
|
||||
|
||||
class ResponseReasoningSummaryTextDelta(ResponseEvent):
|
||||
type: Literal["response.reasoning_summary_text.delta"] = (
|
||||
"response.reasoning_summary_text.delta"
|
||||
)
|
||||
item_id: str = "item_1234"
|
||||
output_index: int = 0
|
||||
content_index: int = 0
|
||||
delta: str = ""
|
||||
|
||||
|
||||
class ResponseReasoningTextDelta(ResponseEvent):
|
||||
type: Literal["response.reasoning_text.delta"] = "response.reasoning_text.delta"
|
||||
item_id: str = "item_1234"
|
||||
output_index: int = 0
|
||||
content_index: int = 0
|
||||
delta: str = ""
|
||||
|
||||
|
||||
class ResponseReasoningTextDone(ResponseEvent):
|
||||
type: Literal["response.reasoning_text.done"] = "response.reasoning_text.done"
|
||||
item_id: str = "item_1234"
|
||||
output_index: int = 0
|
||||
content_index: int = 0
|
||||
text: str = ""
|
||||
|
||||
|
||||
class ResponseOutputItemAdded(ResponseEvent):
|
||||
type: Literal["response.output_item.added"] = "response.output_item.added"
|
||||
output_index: int = 0
|
||||
item: Union[Item, ReasoningItem, FunctionCallItem, WebSearchCallItem]
|
||||
|
||||
|
||||
class ResponseOutputItemDone(ResponseEvent):
|
||||
type: Literal["response.output_item.done"] = "response.output_item.done"
|
||||
output_index: int = 0
|
||||
item: Union[Item, ReasoningItem, FunctionCallItem, WebSearchCallItem]
|
||||
|
||||
|
||||
class ResponseInProgressEvent(ResponseEvent):
|
||||
type: Literal["response.in_progress"]
|
||||
response: ResponseObject
|
||||
|
||||
|
||||
class ResponseContentPartAdded(ResponseEvent):
|
||||
type: Literal["response.content_part.added"] = "response.content_part.added"
|
||||
item_id: str = "item_1234"
|
||||
output_index: int = 0
|
||||
content_index: int = 0
|
||||
part: Union[TextContentItem, ReasoningTextContentItem]
|
||||
|
||||
|
||||
class ResponseOutputTextDone(ResponseEvent):
|
||||
type: Literal["response.output_text.done"] = "response.output_text.done"
|
||||
item_id: str = "item_1234"
|
||||
output_index: int = 0
|
||||
content_index: int = 0
|
||||
text: str = ""
|
||||
logprobs: list = []
|
||||
|
||||
|
||||
class ResponseContentPartDone(ResponseEvent):
|
||||
type: Literal["response.content_part.done"] = "response.content_part.done"
|
||||
item_id: str = "item_1234"
|
||||
output_index: int = 0
|
||||
content_index: int = 0
|
||||
part: Union[TextContentItem, ReasoningTextContentItem]
|
||||
|
||||
class ResponseOutputTextAnnotationAdded(ResponseEvent):
|
||||
type: Literal["response.output_text.annotation.added"] = "response.output_text.annotation.added"
|
||||
item_id: str = "item_1234"
|
||||
output_index: int = 0
|
||||
content_index: int = 0
|
||||
annotation_index: int = 0
|
||||
annotation: UrlCitation
|
||||
|
||||
class ResponseWebSearchCallInProgress(ResponseEvent):
|
||||
type: Literal["response.web_search_call.in_progress"] = "response.web_search_call.in_progress"
|
||||
output_index: int = 0
|
||||
item_id: str = "item_1234"
|
||||
|
||||
class ResponseWebSearchCallSearching(ResponseEvent):
|
||||
type: Literal["response.web_search_call.searching"] = "response.web_search_call.searching"
|
||||
output_index: int = 0
|
||||
item_id: str = "item_1234"
|
||||
|
||||
class ResponseWebSearchCallCompleted(ResponseEvent):
|
||||
type: Literal["response.web_search_call.completed"] = "response.web_search_call.completed"
|
||||
output_index: int = 0
|
||||
item_id: str = "item_1234"
|
||||
0
gpt_oss/responses_api/inference/__init__.py
Normal file
0
gpt_oss/responses_api/inference/__init__.py
Normal file
78
gpt_oss/responses_api/inference/metal.py
Normal file
78
gpt_oss/responses_api/inference/metal.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Metal backend for :mod:`gpt_oss.responses_api`."""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from gpt_oss.metal import Context, Model
|
||||
|
||||
|
||||
def setup_model(checkpoint: str) -> Callable[[list[int], float], int]:
|
||||
"""Load the Metal model and return an inference function."""
|
||||
|
||||
model = Model(checkpoint)
|
||||
context = Context(model)
|
||||
|
||||
def lcp(cache: list[int], inp: list[int]) -> list[int]:
|
||||
i = 0
|
||||
max_len = min(len(cache), len(inp))
|
||||
while i < max_len and cache[i] == inp[i]:
|
||||
i += 1
|
||||
return cache[:i]
|
||||
|
||||
tokens_so_far = []
|
||||
|
||||
def infer_next_token(
|
||||
tokens: list[int], temperature: float = 0.0, new_request: bool = False
|
||||
) -> int:
|
||||
"""Infer next token using incremental LCP caching when possible."""
|
||||
nonlocal tokens_so_far
|
||||
|
||||
# Fast path: first call or explicitly new request.
|
||||
if new_request or not tokens_so_far:
|
||||
context.reset()
|
||||
for t in tokens:
|
||||
context.append(t)
|
||||
tokens_so_far = tokens.copy()
|
||||
context.process()
|
||||
return int(context.sample(temperature=temperature))
|
||||
|
||||
# Longest common prefix length
|
||||
overlap = lcp(tokens_so_far, tokens)
|
||||
ol = len(overlap)
|
||||
prev_len = len(tokens_so_far)
|
||||
cur_len = len(tokens)
|
||||
|
||||
diverged_midstream = (ol < prev_len) and (
|
||||
ol < cur_len
|
||||
) # mismatch not at the end
|
||||
|
||||
if diverged_midstream:
|
||||
# safest: rebuild
|
||||
context.reset()
|
||||
for t in tokens:
|
||||
context.append(t)
|
||||
tokens_so_far = tokens.copy()
|
||||
context.process()
|
||||
return int(context.sample(temperature=temperature))
|
||||
|
||||
if cur_len > prev_len:
|
||||
# pure extension (good for KV reuse)
|
||||
extension = tokens[prev_len:]
|
||||
for t in extension:
|
||||
context.append(t)
|
||||
tokens_so_far = tokens.copy()
|
||||
context.process()
|
||||
return int(context.sample(temperature=temperature))
|
||||
|
||||
if cur_len < prev_len:
|
||||
# truncation/backspace; easiest correct behavior is rebuild
|
||||
context.reset()
|
||||
for t in tokens:
|
||||
context.append(t)
|
||||
tokens_so_far = tokens.copy()
|
||||
context.process()
|
||||
return int(context.sample(temperature=temperature))
|
||||
|
||||
# cur_len == prev_len and everything matches => no new tokens; just sample.
|
||||
return int(context.sample(temperature=temperature))
|
||||
|
||||
return infer_next_token
|
||||
192
gpt_oss/responses_api/inference/ollama.py
Normal file
192
gpt_oss/responses_api/inference/ollama.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
NOTE: this is a stiched together implementation that uses Ollama for inference. It's primarily used
|
||||
for testing and development. It does not leverage any prompt caching or other optimizations and
|
||||
can therefore be slow between turns.
|
||||
"""
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
import requests
|
||||
|
||||
from openai_harmony import load_harmony_encoding, HarmonyEncodingName
|
||||
|
||||
EOS_TOKEN = 200002 # only used on hard timeout
|
||||
|
||||
# Tunables
|
||||
POLL_INTERVAL_S = 0.01 # 10ms between buffer checks
|
||||
CALL_MAX_WAIT_S = 0.250 # max time to block inside a single infer call
|
||||
NO_TOKEN_TIMEOUT_S = 15.0 # overall inactivity timeout before emitting EOS
|
||||
FIRST_BYTE_TIMEOUT_S = 30.0 # time to wait for first token before EOS
|
||||
|
||||
# Shared state
|
||||
_token_buffer: list[int] = []
|
||||
_buffer_lock = threading.Lock()
|
||||
_stream_thread: Optional[threading.Thread] = None
|
||||
_stream_done = threading.Event()
|
||||
_stream_error: Optional[Exception] = None
|
||||
_last_progress_ts: float = 0.0 # updated whenever we enqueue or dequeue tokens
|
||||
_previous_request_tokens: list[int] = []
|
||||
|
||||
def lcp(cache: list[int], inp: list[int]) -> list[int]:
|
||||
i = 0
|
||||
max_len = min(len(cache), len(inp))
|
||||
while i < max_len and cache[i] == inp[i]:
|
||||
i += 1
|
||||
return cache[:i]
|
||||
|
||||
def _now():
|
||||
return time.monotonic()
|
||||
|
||||
def _touch_progress():
|
||||
global _last_progress_ts
|
||||
_last_progress_ts = _now()
|
||||
|
||||
def _reset_stream_state():
|
||||
global _token_buffer, _stream_thread, _stream_error
|
||||
with _buffer_lock:
|
||||
_token_buffer = []
|
||||
_stream_done.clear()
|
||||
_stream_thread = None
|
||||
_stream_error = None
|
||||
_touch_progress()
|
||||
|
||||
def setup_model(checkpoint: str) -> Callable[[list[int], float, bool], int]:
|
||||
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
model_name = checkpoint
|
||||
|
||||
def _start_stream(token_ids: list[int], temperature: float):
|
||||
prompt_text = encoding.decode(token_ids)
|
||||
def run():
|
||||
nonlocal prompt_text, temperature
|
||||
global _stream_error
|
||||
global _previous_request_tokens
|
||||
|
||||
accum_text = ""
|
||||
last_len = 0 # number of tokens already emitted
|
||||
|
||||
try:
|
||||
url = "http://localhost:11434/api/generate"
|
||||
context = None
|
||||
if len(_previous_request_tokens) > 0:
|
||||
context = _previous_request_tokens
|
||||
# cache_hit = lcp(_previous_request_tokens, token_ids)
|
||||
# if len(cache_hit) > 0:
|
||||
# context = cache_hit
|
||||
# print(f"Cache hit: {encoding.decode(context)}")
|
||||
# prompt_text = encoding.decode(token_ids[len(context):])
|
||||
|
||||
payload = {
|
||||
"model": model_name,
|
||||
"prompt": prompt_text,
|
||||
"stream": True,
|
||||
"context": context,
|
||||
"options": {"temperature": temperature},
|
||||
}
|
||||
|
||||
with requests.post(url, json=payload, stream=True, timeout=60) as resp:
|
||||
resp.raise_for_status()
|
||||
for line in resp.iter_lines(decode_unicode=True):
|
||||
if not line:
|
||||
continue
|
||||
obj = json.loads(line)
|
||||
|
||||
if isinstance(obj.get("response"), str):
|
||||
accum_text += obj["response"]
|
||||
toks = encoding.encode(accum_text, allowed_special="all")
|
||||
if len(toks) > last_len:
|
||||
new_toks = toks[last_len:]
|
||||
with _buffer_lock:
|
||||
_token_buffer.extend(new_toks)
|
||||
last_len = len(toks)
|
||||
_touch_progress()
|
||||
|
||||
if obj.get("done", False):
|
||||
_token_buffer.append(EOS_TOKEN)
|
||||
last_len = len(toks)
|
||||
_touch_progress()
|
||||
context = obj.get("context")
|
||||
if context and len(context) > 0:
|
||||
_previous_request_tokens = context
|
||||
break
|
||||
|
||||
_stream_done.set()
|
||||
|
||||
except Exception as e:
|
||||
_stream_error = e
|
||||
_stream_done.set()
|
||||
|
||||
t = threading.Thread(target=run, name="ollama-stream", daemon=True)
|
||||
t.start()
|
||||
return t
|
||||
|
||||
def infer_next_token(
|
||||
tokens: list[int], temperature: float = 0.0, new_request: bool = False
|
||||
) -> int:
|
||||
"""
|
||||
- Starts a new Ollama stream on new_request.
|
||||
- Forwards tokens as they arrive.
|
||||
- Only emits EOS_TOKEN if we exceed an inactivity timeout.
|
||||
"""
|
||||
global _stream_thread
|
||||
|
||||
if new_request:
|
||||
_reset_stream_state()
|
||||
_stream_thread = _start_stream(token_ids=tokens, temperature=temperature)
|
||||
# Wait for first byte within FIRST_BYTE_TIMEOUT_S (without emitting EOS early)
|
||||
start = _now()
|
||||
while _now() - start < FIRST_BYTE_TIMEOUT_S:
|
||||
with _buffer_lock:
|
||||
if _token_buffer:
|
||||
tok = _token_buffer.pop(0)
|
||||
_touch_progress()
|
||||
return tok
|
||||
if _stream_error is not None:
|
||||
raise RuntimeError(f"Ollama stream error: {_stream_error!r}")
|
||||
# If Ollama finished instantly with no output, continue loop until timeout
|
||||
time.sleep(POLL_INTERVAL_S)
|
||||
# Hard first-byte timeout -> emit EOS so the server can stop this request
|
||||
return EOS_TOKEN
|
||||
|
||||
if _stream_error is not None:
|
||||
raise RuntimeError(f"Ollama stream error: {_stream_error!r}")
|
||||
|
||||
# Normal path: wait up to CALL_MAX_WAIT_S for a token to arrive
|
||||
wait_start = _now()
|
||||
while _now() - wait_start < CALL_MAX_WAIT_S:
|
||||
with _buffer_lock:
|
||||
if _token_buffer:
|
||||
tok = _token_buffer.pop(0)
|
||||
_touch_progress()
|
||||
return tok
|
||||
# No token yet; if we've been idle too long overall, end with EOS
|
||||
if _now() - _last_progress_ts > NO_TOKEN_TIMEOUT_S:
|
||||
return EOS_TOKEN
|
||||
time.sleep(POLL_INTERVAL_S)
|
||||
|
||||
# Still no token in this call slice. Do NOT send EOS unless we've timed out.
|
||||
if _now() - _last_progress_ts > NO_TOKEN_TIMEOUT_S:
|
||||
return EOS_TOKEN
|
||||
|
||||
# Tell caller to call us again; block minimally by returning *nothing new*.
|
||||
# We must return an int; safest is to wait a tiny bit longer for a token.
|
||||
# If still none, keep returning only after short waits. Avoid EOS here.
|
||||
# One more short wait to reduce hot-looping:
|
||||
time.sleep(POLL_INTERVAL_S)
|
||||
with _buffer_lock:
|
||||
if _token_buffer:
|
||||
tok = _token_buffer.pop(0)
|
||||
_touch_progress()
|
||||
return tok
|
||||
|
||||
# As a last resort for this call slice, return EOS only on true inactivity timeout.
|
||||
if _now() - _last_progress_ts > NO_TOKEN_TIMEOUT_S:
|
||||
return EOS_TOKEN
|
||||
|
||||
# If we reach here, we still haven't got a token—ask the caller to call again soon.
|
||||
# Return a harmless token that the server will replace/ignore if your interface supports it.
|
||||
# If your interface does NOT allow a sentinel, keep the short-blocking behavior above.
|
||||
return EOS_TOKEN if False else 0 # replace `0` with a PAD/NOOP token your server ignores
|
||||
|
||||
return infer_next_token
|
||||
142
gpt_oss/responses_api/inference/stub.py
Normal file
142
gpt_oss/responses_api/inference/stub.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import time
|
||||
from typing import Callable
|
||||
|
||||
fake_tokens = [
|
||||
200005,
|
||||
35644,
|
||||
200008,
|
||||
23483,
|
||||
316,
|
||||
1199,
|
||||
1114,
|
||||
717,
|
||||
170154,
|
||||
13,
|
||||
200007,
|
||||
200006,
|
||||
173781,
|
||||
200005,
|
||||
35644,
|
||||
316,
|
||||
28,
|
||||
44580,
|
||||
775,
|
||||
170154,
|
||||
464,
|
||||
91,
|
||||
542,
|
||||
141043,
|
||||
91,
|
||||
29,
|
||||
4108,
|
||||
200008,
|
||||
10848,
|
||||
7693,
|
||||
7534,
|
||||
28499,
|
||||
18826,
|
||||
18583,
|
||||
200012,
|
||||
]
|
||||
fake_tokens = [
|
||||
200005,
|
||||
35644,
|
||||
200008,
|
||||
1844,
|
||||
31064,
|
||||
25,
|
||||
392,
|
||||
4827,
|
||||
382,
|
||||
220,
|
||||
17,
|
||||
659,
|
||||
220,
|
||||
17,
|
||||
16842,
|
||||
12295,
|
||||
81645,
|
||||
13,
|
||||
51441,
|
||||
6052,
|
||||
13,
|
||||
200007,
|
||||
200006,
|
||||
173781,
|
||||
200005,
|
||||
17196,
|
||||
200008,
|
||||
17,
|
||||
659,
|
||||
220,
|
||||
17,
|
||||
314,
|
||||
220,
|
||||
19,
|
||||
13,
|
||||
9552,
|
||||
238,
|
||||
242,
|
||||
200002,
|
||||
]
|
||||
# fake_tokens = [200005, 35644, 200008, 976, 1825, 31064, 25, 392, 25216, 29400, 290, 11122, 306, 52768, 2117, 16842, 1416, 1309, 316, 2281, 198, 68, 290, 2208, 11122, 13, 1416, 679, 261, 1114, 717, 170154, 484, 44390, 261, 5100, 1621, 26, 581, 1757, 2005, 198, 75, 480, 483, 5100, 392, 137956, 2117, 11, 13180, 4050, 7801, 4733, 290, 11122, 5377, 484, 290, 1114, 7377, 13, 1416, 1309, 260, 198, 78, 1199, 290, 1114, 4584, 364, 58369, 2421, 717, 170154, 483, 5100, 392, 137956, 2117, 11, 13180, 4050, 200007, 200006, 173781, 200005, 12606, 815, 260, 198, 78, 28, 117673, 3490]
|
||||
# fake_tokens = [
|
||||
# 198,
|
||||
# 200005,
|
||||
# 35644,
|
||||
# 200008,
|
||||
# 23483,
|
||||
# 316,
|
||||
# 1199,
|
||||
# 1114,
|
||||
# 717,
|
||||
# 170154,
|
||||
# 13,
|
||||
# 200007,
|
||||
# 200006,
|
||||
# 173781,
|
||||
# 200005,
|
||||
# 12606,
|
||||
# 815,
|
||||
# 316,
|
||||
# 32455,
|
||||
# 106847,
|
||||
# 316,
|
||||
# 28,
|
||||
# 44580,
|
||||
# 775,
|
||||
# 170154,
|
||||
# 464,
|
||||
# 91,
|
||||
# 542,
|
||||
# 141043,
|
||||
# 91,
|
||||
# 29,
|
||||
# 4108,
|
||||
# 200008,
|
||||
# 10848,
|
||||
# 7693,
|
||||
# 7534,
|
||||
# 28499,
|
||||
# 18826,
|
||||
# 18583,
|
||||
# 200012,
|
||||
# 198,
|
||||
# ]
|
||||
|
||||
token_queue = fake_tokens.copy()
|
||||
|
||||
|
||||
def stub_infer_next_token(
|
||||
tokens: list[int], temperature: float = 0.0, new_request: bool = False
|
||||
) -> int:
|
||||
global token_queue
|
||||
next_tok = token_queue.pop(0)
|
||||
if len(token_queue) == 0:
|
||||
token_queue = fake_tokens.copy()
|
||||
time.sleep(0.1)
|
||||
return next_tok
|
||||
|
||||
|
||||
def setup_model(_checkpoint: str) -> Callable[[list[int], float], int]:
|
||||
return stub_infer_next_token
|
||||
102
gpt_oss/responses_api/inference/triton.py
Normal file
102
gpt_oss/responses_api/inference/triton.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import datetime
|
||||
import os
|
||||
from typing import Callable
|
||||
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from gpt_oss.triton.model import Cache, ModelConfig, Transformer
|
||||
|
||||
DEFAULT_TEMPERATURE = 0.0
|
||||
CONTEXT = 16_384
|
||||
CONCURRENT_SESSIONS = 1
|
||||
|
||||
rank = int(
|
||||
os.environ.get("RANK", 0)
|
||||
) # set this env var to another value to run on other GPUs
|
||||
|
||||
|
||||
def load_model(checkpoint: str):
|
||||
print(f"[{rank}] loading model...")
|
||||
|
||||
torch.cuda.set_device(rank)
|
||||
torch.set_grad_enabled(False)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
|
||||
# Load model
|
||||
model = Transformer.from_checkpoint(checkpoint, device=device)
|
||||
|
||||
print(f"[{rank}] loaded")
|
||||
return model, device
|
||||
|
||||
|
||||
def get_infer_next_token(model, device):
|
||||
caches = [
|
||||
Cache(CONCURRENT_SESSIONS, CONTEXT, model.config.num_key_value_heads)
|
||||
for _ in range(len(model.block))
|
||||
]
|
||||
# offsets = torch.zeros(CONCURRENT_SESSIONS, dtype=torch.int32, device=device) # TBD
|
||||
input_token = torch.zeros(
|
||||
1, dtype=torch.int32, device=device
|
||||
) # add concurrent sessions support
|
||||
tokens_so_far = []
|
||||
|
||||
model.prefill(torch.zeros(1, 4, dtype=torch.int32, device=device), caches)
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
logits = model(input_token[None, :], caches=caches)[0]
|
||||
|
||||
def lcp(cache: list[int], inp: list[int]) -> list[int]:
|
||||
i = 0
|
||||
max_len = min(len(cache), len(inp))
|
||||
while i < max_len and cache[i] == inp[i]:
|
||||
i += 1
|
||||
return cache[:i]
|
||||
|
||||
def sample_next_token(
|
||||
logits: torch.Tensor, temperature: float = DEFAULT_TEMPERATURE
|
||||
) -> int:
|
||||
"""Executed only on rank 0."""
|
||||
if temperature == 0.0:
|
||||
return torch.argmax(logits[-1, :], dim=-1).item()
|
||||
probs = torch.softmax(logits * (1.0 / temperature), dim=-1)
|
||||
return torch.multinomial(probs[-1, :], num_samples=1).item()
|
||||
|
||||
@torch.inference_mode()
|
||||
def infer_next_token(
|
||||
tokens: list[int],
|
||||
temperature: float = DEFAULT_TEMPERATURE,
|
||||
new_request: bool = False,
|
||||
) -> int:
|
||||
nonlocal tokens_so_far
|
||||
tokens_so_far = lcp(tokens_so_far, tokens)
|
||||
for cache in caches:
|
||||
cache.truncate(len(tokens_so_far))
|
||||
all_tokens = tokens # for pdb
|
||||
tokens = tokens[len(tokens_so_far) :]
|
||||
|
||||
if len(tokens) > 1:
|
||||
model.prefill(
|
||||
torch.as_tensor(tokens[:-1], dtype=torch.int32, device=device)[None, :],
|
||||
caches,
|
||||
)
|
||||
|
||||
if len(tokens) == 0:
|
||||
breakpoint()
|
||||
|
||||
input_token[-1] = tokens[-1]
|
||||
graph.replay()
|
||||
|
||||
# decide next token on rank‑0
|
||||
next_tok = sample_next_token(logits, temperature=temperature)
|
||||
|
||||
return next_tok
|
||||
|
||||
return infer_next_token
|
||||
|
||||
|
||||
def setup_model(checkpoint: str) -> Callable[[list[int], float], int]:
|
||||
model, device = load_model(checkpoint)
|
||||
infer_next_token = get_infer_next_token(model, device)
|
||||
return infer_next_token
|
||||
84
gpt_oss/responses_api/inference/vllm.py
Normal file
84
gpt_oss/responses_api/inference/vllm.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
NOTE: this is not the most efficient way to use vLLM. It's a simple implementation that infers
|
||||
one token at a time to mimic the behavior of the Triton implementation.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
# vLLM imports
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.inputs import TokensPrompt
|
||||
|
||||
DEFAULT_TEMPERATURE = 0.0
|
||||
TP = os.environ.get("TP", 2)
|
||||
|
||||
def load_model(checkpoint: str):
|
||||
"""
|
||||
Create the vLLM engine. We enable prefix caching so repeated prefixes
|
||||
across calls can reuse KV cache for faster prefill.
|
||||
"""
|
||||
|
||||
llm = LLM(
|
||||
model=checkpoint,
|
||||
tensor_parallel_size=TP, # set >1 if you want TP across GPUs
|
||||
enable_prefix_caching=True, # reuse KV for shared prefixes
|
||||
disable_log_stats=True, # uncomment to quiet logs
|
||||
)
|
||||
|
||||
return llm
|
||||
|
||||
|
||||
def get_infer_next_token(llm: LLM):
|
||||
"""
|
||||
Return a callable with the same shape as your original:
|
||||
infer_next_token(tokens: List[int], temperature: float, new_request: bool) -> int
|
||||
|
||||
Implementation detail:
|
||||
- We issue a single-token generation with TokensPrompt(prompt_token_ids=tokens).
|
||||
- vLLM handles sampling (temperature=0 => greedy).
|
||||
- With enable_prefix_caching=True, the shared prefix prefill can be reused
|
||||
across calls that share the same prefix.
|
||||
"""
|
||||
|
||||
# Maintain compatibility with your previous closure signature.
|
||||
def infer_next_token(
|
||||
tokens: List[int],
|
||||
temperature: float = DEFAULT_TEMPERATURE,
|
||||
new_request: bool = False, # kept for interface compatibility; unused here
|
||||
) -> int:
|
||||
if not tokens:
|
||||
raise ValueError("tokens must contain at least one input token id")
|
||||
|
||||
sampling = SamplingParams(
|
||||
temperature=float(temperature),
|
||||
max_tokens=1, # we only want the next token
|
||||
n=1, # single continuation
|
||||
# You can expose/enable more controls here (top_p, top_k, etc.)
|
||||
)
|
||||
|
||||
# Provide token IDs directly (no re-tokenization).
|
||||
outputs = llm.generate(
|
||||
TokensPrompt(prompt_token_ids=tokens),
|
||||
sampling_params=sampling,
|
||||
)
|
||||
|
||||
if not outputs or not outputs[0].outputs:
|
||||
raise RuntimeError("vLLM returned empty outputs")
|
||||
|
||||
gen = outputs[0].outputs[0]
|
||||
if not gen.token_ids:
|
||||
# If the model immediately finished (e.g., EOS), decide how you'd like
|
||||
# to signal that. Here we raise; you could also return an EOS id.
|
||||
raise RuntimeError("No next token was generated (possibly EOS).")
|
||||
|
||||
next_tok = int(gen.token_ids[0])
|
||||
return next_tok
|
||||
|
||||
return infer_next_token
|
||||
|
||||
|
||||
def setup_model(checkpoint: str) -> Callable[[List[int], float, bool], int]:
|
||||
llm = load_model(checkpoint)
|
||||
infer_next_token = get_infer_next_token(llm)
|
||||
return infer_next_token
|
||||
56
gpt_oss/responses_api/serve.py
Normal file
56
gpt_oss/responses_api/serve.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# torchrun --nproc-per-node=4 serve.py
|
||||
|
||||
import argparse
|
||||
|
||||
import uvicorn
|
||||
from openai_harmony import (
|
||||
HarmonyEncodingName,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
from .api_server import create_api_server
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Responses API server")
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
metavar="FILE",
|
||||
type=str,
|
||||
help="Path to the SafeTensors checkpoint",
|
||||
default="~/model",
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
metavar="PORT",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Port to run the server on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--inference-backend",
|
||||
metavar="BACKEND",
|
||||
type=str,
|
||||
help="Inference backend to use",
|
||||
# default to metal on macOS, triton on other platforms
|
||||
default="metal" if __import__("platform").system() == "Darwin" else "triton",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.inference_backend == "triton":
|
||||
from .inference.triton import setup_model
|
||||
elif args.inference_backend == "stub":
|
||||
from .inference.stub import setup_model
|
||||
elif args.inference_backend == "metal":
|
||||
from .inference.metal import setup_model
|
||||
elif args.inference_backend == "ollama":
|
||||
from .inference.ollama import setup_model
|
||||
elif args.inference_backend == "vllm":
|
||||
from .inference.vllm import setup_model
|
||||
else:
|
||||
raise ValueError(f"Invalid inference backend: {args.inference_backend}")
|
||||
|
||||
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
|
||||
infer_next_token = setup_model(args.checkpoint)
|
||||
uvicorn.run(create_api_server(infer_next_token, encoding), port=args.port)
|
||||
152
gpt_oss/responses_api/types.py
Normal file
152
gpt_oss/responses_api/types.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
|
||||
from openai_harmony import ReasoningEffort
|
||||
from pydantic import BaseModel
|
||||
|
||||
MODEL_IDENTIFIER = "gpt-oss-120b"
|
||||
DEFAULT_TEMPERATURE = 0.0
|
||||
REASONING_EFFORT = ReasoningEffort.LOW
|
||||
DEFAULT_MAX_OUTPUT_TOKENS = 10_000
|
||||
|
||||
class UrlCitation(BaseModel):
|
||||
type: Literal["url_citation"]
|
||||
end_index: int
|
||||
start_index: int
|
||||
url: str
|
||||
title: str
|
||||
|
||||
class TextContentItem(BaseModel):
|
||||
type: Union[Literal["text"], Literal["input_text"], Literal["output_text"]]
|
||||
text: str
|
||||
status: Optional[str] = "completed"
|
||||
annotations: Optional[list[UrlCitation]] = None
|
||||
|
||||
|
||||
class SummaryTextContentItem(BaseModel):
|
||||
# using summary for compatibility with the existing API
|
||||
type: Literal["summary_text"]
|
||||
text: str
|
||||
|
||||
|
||||
class ReasoningTextContentItem(BaseModel):
|
||||
type: Literal["reasoning_text"]
|
||||
text: str
|
||||
|
||||
|
||||
class ReasoningItem(BaseModel):
|
||||
id: str = "rs_1234"
|
||||
type: Literal["reasoning"]
|
||||
summary: list[SummaryTextContentItem]
|
||||
content: Optional[list[ReasoningTextContentItem]] = []
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
type: Optional[Literal["message"]] = "message"
|
||||
role: Literal["user", "assistant", "system"]
|
||||
content: Union[list[TextContentItem], str]
|
||||
status: Union[Literal["in_progress", "completed", "incomplete"], None] = None
|
||||
|
||||
|
||||
class FunctionCallItem(BaseModel):
|
||||
type: Literal["function_call"]
|
||||
name: str
|
||||
arguments: str
|
||||
status: Literal["in_progress", "completed", "incomplete"] = "completed"
|
||||
id: str = "fc_1234"
|
||||
call_id: str = "call_1234"
|
||||
|
||||
|
||||
class FunctionCallOutputItem(BaseModel):
|
||||
type: Literal["function_call_output"]
|
||||
call_id: str = "call_1234"
|
||||
output: str
|
||||
|
||||
class WebSearchActionSearch(BaseModel):
|
||||
type: Literal["search"]
|
||||
query: Optional[str] = None
|
||||
|
||||
class WebSearchActionOpenPage(BaseModel):
|
||||
type: Literal["open_page"]
|
||||
url: Optional[str] = None
|
||||
|
||||
class WebSearchActionFind(BaseModel):
|
||||
type: Literal["find"]
|
||||
pattern: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
|
||||
class WebSearchCallItem(BaseModel):
|
||||
type: Literal["web_search_call"]
|
||||
id: str = "ws_1234"
|
||||
status: Literal["in_progress", "completed", "incomplete"] = "completed"
|
||||
action: Union[WebSearchActionSearch, WebSearchActionOpenPage, WebSearchActionFind]
|
||||
|
||||
class Error(BaseModel):
|
||||
code: str
|
||||
message: str
|
||||
|
||||
|
||||
class IncompleteDetails(BaseModel):
|
||||
reason: str
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class FunctionToolDefinition(BaseModel):
|
||||
type: Literal["function"]
|
||||
name: str
|
||||
parameters: dict # this should be typed stricter if you add strict mode
|
||||
strict: bool = False # change this if you support strict mode
|
||||
description: Optional[str] = ""
|
||||
|
||||
|
||||
class BrowserToolConfig(BaseModel):
|
||||
type: Literal["browser_search"]
|
||||
|
||||
|
||||
class ReasoningConfig(BaseModel):
|
||||
effort: Literal["low", "medium", "high"] = REASONING_EFFORT
|
||||
|
||||
|
||||
class ResponsesRequest(BaseModel):
|
||||
instructions: Optional[str] = None
|
||||
max_output_tokens: Optional[int] = DEFAULT_MAX_OUTPUT_TOKENS
|
||||
input: Union[
|
||||
str, list[Union[Item, ReasoningItem, FunctionCallItem, FunctionCallOutputItem, WebSearchCallItem]]
|
||||
]
|
||||
model: Optional[str] = MODEL_IDENTIFIER
|
||||
stream: Optional[bool] = False
|
||||
tools: Optional[list[Union[FunctionToolDefinition, BrowserToolConfig]]] = []
|
||||
reasoning: Optional[ReasoningConfig] = ReasoningConfig()
|
||||
metadata: Optional[Dict[str, Any]] = {}
|
||||
tool_choice: Optional[Literal["auto", "none"]] = "auto"
|
||||
parallel_tool_calls: Optional[bool] = False
|
||||
store: Optional[bool] = False
|
||||
previous_response_id: Optional[str] = None
|
||||
temperature: Optional[float] = DEFAULT_TEMPERATURE
|
||||
include: Optional[list[str]] = None
|
||||
|
||||
|
||||
class ResponseObject(BaseModel):
|
||||
output: list[Union[Item, ReasoningItem, FunctionCallItem, FunctionCallOutputItem, WebSearchCallItem]]
|
||||
created_at: int
|
||||
usage: Optional[Usage] = None
|
||||
status: Literal["completed", "failed", "incomplete", "in_progress"] = "in_progress"
|
||||
background: None = None
|
||||
error: Optional[Error] = None
|
||||
incomplete_details: Optional[IncompleteDetails] = None
|
||||
instructions: Optional[str] = None
|
||||
max_output_tokens: Optional[int] = None
|
||||
max_tool_calls: Optional[int] = None
|
||||
metadata: Optional[Dict[str, Any]] = {}
|
||||
model: Optional[str] = MODEL_IDENTIFIER
|
||||
parallel_tool_calls: Optional[bool] = False
|
||||
previous_response_id: Optional[str] = None
|
||||
id: Optional[str] = "resp_1234"
|
||||
object: Optional[str] = "response"
|
||||
text: Optional[Dict[str, Any]] = None
|
||||
tool_choice: Optional[str] = "auto"
|
||||
top_p: Optional[int] = 1
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user