Add support for rouge NLP metric

This commit is contained in:
Ian Webster
2023-06-10 15:37:24 -07:00
parent d2b093ab1e
commit 0c2f5d4c8c
8 changed files with 100 additions and 3 deletions

View File

@@ -98,9 +98,10 @@ See [Test assertions](https://promptfoo.dev/docs/configuration/expected-outputs)
| `is-json` | output is valid json | | `is-json` | output is valid json |
| `contains-json` | output contains valid json | | `contains-json` | output contains valid json |
| `javascript` | provided Javascript function validates the output | | `javascript` | provided Javascript function validates the output |
| `webhook` | provided webhook returns `{pass: true} | | `webhook` | provided webhook returns `{pass: true}` |
| `similar` | embeddings and cosine similarity are above a threshold | | `similar` | embeddings and cosine similarity are above a threshold |
| `llm-rubric` | LLM output matches a given rubric, using a Language Model to grade output | | `llm-rubric` | LLM output matches a given rubric, using a Language Model to grade output |
| `rouge-n` | Rouge-N score is above a given threshold |
Every test type can be negated by prepending `not-`. For example, `not-equals` or `not-regex`. Every test type can be negated by prepending `not-`. For example, `not-equals` or `not-regex`.

28
package-lock.json generated
View File

@@ -27,6 +27,7 @@
"node-fetch": "^2.6.7", "node-fetch": "^2.6.7",
"nunjucks": "^3.2.4", "nunjucks": "^3.2.4",
"opener": "^1.5.2", "opener": "^1.5.2",
"rouge": "^1.0.3",
"socket.io": "^4.6.1", "socket.io": "^4.6.1",
"tiny-invariant": "^1.3.1", "tiny-invariant": "^1.3.1",
"winston": "^3.8.2" "winston": "^3.8.2"
@@ -3973,6 +3974,12 @@
"signal-exit": "^3.0.2" "signal-exit": "^3.0.2"
} }
}, },
"node_modules/lodash-node": {
"version": "2.4.1",
"resolved": "https://registry.npmjs.org/lodash-node/-/lodash-node-2.4.1.tgz",
"integrity": "sha512-egEt8eNQp2kZWRmngahiqMoDCDCENv3uM188S7Ed5t4k3v6RrLELXC+FqLNMUnhCo7gvQX3G1V8opK/Lcslahg==",
"deprecated": "This package is discontinued. Use lodash@^4.0.0."
},
"node_modules/lodash.clonedeep": { "node_modules/lodash.clonedeep": {
"version": "4.5.0", "version": "4.5.0",
"resolved": "https://registry.npmjs.org/lodash.clonedeep/-/lodash.clonedeep-4.5.0.tgz", "resolved": "https://registry.npmjs.org/lodash.clonedeep/-/lodash.clonedeep-4.5.0.tgz",
@@ -4665,6 +4672,14 @@
"node": ">=10" "node": ">=10"
} }
}, },
"node_modules/rouge": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/rouge/-/rouge-1.0.3.tgz",
"integrity": "sha512-YCt74Dxsi99E8/uh943FTa80EmGboaOu1ij4q8WD4EAGyvyWYaH7MRHorrDbGgLY7iFUwDwyW/g9KJZx7D5fUQ==",
"dependencies": {
"lodash-node": "^2.4.1"
}
},
"node_modules/safe-buffer": { "node_modules/safe-buffer": {
"version": "5.2.1", "version": "5.2.1",
"resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz",
@@ -8647,6 +8662,11 @@
"signal-exit": "^3.0.2" "signal-exit": "^3.0.2"
} }
}, },
"lodash-node": {
"version": "2.4.1",
"resolved": "https://registry.npmjs.org/lodash-node/-/lodash-node-2.4.1.tgz",
"integrity": "sha512-egEt8eNQp2kZWRmngahiqMoDCDCENv3uM188S7Ed5t4k3v6RrLELXC+FqLNMUnhCo7gvQX3G1V8opK/Lcslahg=="
},
"lodash.clonedeep": { "lodash.clonedeep": {
"version": "4.5.0", "version": "4.5.0",
"resolved": "https://registry.npmjs.org/lodash.clonedeep/-/lodash.clonedeep-4.5.0.tgz", "resolved": "https://registry.npmjs.org/lodash.clonedeep/-/lodash.clonedeep-4.5.0.tgz",
@@ -9130,6 +9150,14 @@
"integrity": "sha512-X2UW6Nw3n/aMgDVy+0rSqgHlv39WZAlZrXCdnbyEiKm17DSqHX4MmQMaST3FbeWR5FTuRcUwYAziZajji0Y7mg==", "integrity": "sha512-X2UW6Nw3n/aMgDVy+0rSqgHlv39WZAlZrXCdnbyEiKm17DSqHX4MmQMaST3FbeWR5FTuRcUwYAziZajji0Y7mg==",
"dev": true "dev": true
}, },
"rouge": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/rouge/-/rouge-1.0.3.tgz",
"integrity": "sha512-YCt74Dxsi99E8/uh943FTa80EmGboaOu1ij4q8WD4EAGyvyWYaH7MRHorrDbGgLY7iFUwDwyW/g9KJZx7D5fUQ==",
"requires": {
"lodash-node": "^2.4.1"
}
},
"safe-buffer": { "safe-buffer": {
"version": "5.2.1", "version": "5.2.1",
"resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz",

View File

@@ -77,6 +77,7 @@
"node-fetch": "^2.6.7", "node-fetch": "^2.6.7",
"nunjucks": "^3.2.4", "nunjucks": "^3.2.4",
"opener": "^1.5.2", "opener": "^1.5.2",
"rouge": "^1.0.3",
"socket.io": "^4.6.1", "socket.io": "^4.6.1",
"tiny-invariant": "^1.3.1", "tiny-invariant": "^1.3.1",
"winston": "^3.8.2" "winston": "^3.8.2"

View File

@@ -1,3 +1,4 @@
import rouge from 'rouge';
import invariant from 'tiny-invariant'; import invariant from 'tiny-invariant';
import nunjucks from 'nunjucks'; import nunjucks from 'nunjucks';
@@ -18,6 +19,31 @@ const SIMILAR_REGEX = /similar(?::|\((\d+(\.\d+)?)\):)/;
const DEFAULT_SEMANTIC_SIMILARITY_THRESHOLD = 0.8; const DEFAULT_SEMANTIC_SIMILARITY_THRESHOLD = 0.8;
function handleRougeScore(
baseType: 'rouge-n',
assertion: Assertion,
expected: string | string[],
output: string,
inverted: boolean,
): GradingResult {
const fnName = baseType[baseType.length - 1] as 'n' | 'l' | 's';
const rougeMethod = rouge[fnName];
const score = rougeMethod(output, expected);
console.log(output, expected, score);
const pass = score >= (assertion.threshold || 0.75) != inverted;
return {
pass,
reason: pass
? `${baseType.toUpperCase()} score ${score} is greater than or equal to threshold ${
assertion.threshold || 0.75
}`
: `${baseType.toUpperCase()} score ${score} is less than threshold ${
assertion.threshold || 0.75
}`,
};
}
export async function runAssertions(test: AtomicTestCase, output: string): Promise<GradingResult> { export async function runAssertions(test: AtomicTestCase, output: string): Promise<GradingResult> {
const tokensUsed = { const tokensUsed = {
total: 0, total: 0,
@@ -248,6 +274,11 @@ ${assertion.value}`,
}; };
} }
if (baseType === 'rouge-n') {
invariant(assertion.value, '"rouge" assertion type must a value (string or string array)');
return handleRougeScore(baseType, assertion, assertion.value, output, inverse);
}
throw new Error('Unknown assertion type: ' + assertion.type); throw new Error('Unknown assertion type: ' + assertion.type);
} }

View File

@@ -125,7 +125,10 @@ type BaseAssertionTypes =
| 'javascript' | 'javascript'
| 'similar' | 'similar'
| 'llm-rubric' | 'llm-rubric'
| 'webhook'; | 'webhook'
| 'rouge-n'
| 'rouge-s'
| 'rouge-l';
type NotPrefixed<T extends string> = `not-${T}`; type NotPrefixed<T extends string> = `not-${T}`;

View File

@@ -452,6 +452,29 @@ describe('runAssertion', () => {
expect(result.pass).toBeFalsy(); expect(result.pass).toBeFalsy();
expect(result.reason).toBe('Webhook error: Webhook response status: 500'); expect(result.reason).toBe('Webhook error: Webhook response status: 500');
}); });
// Test for rouge-n assertion
const rougeNAssertion: Assertion = {
type: 'rouge-n',
value: 'This is the expected output.',
threshold: 0.75,
};
it('should pass when the rouge-n assertion passes', async () => {
const output = 'This is the expected output.';
const result: GradingResult = await runAssertion(rougeNAssertion, {} as AtomicTestCase, output);
expect(result.pass).toBeTruthy();
expect(result.reason).toBe('ROUGE-N score 1 is greater than or equal to threshold 0.75');
});
it('should fail when the rouge-n assertion fails', async () => {
const output = 'some different output';
const result: GradingResult = await runAssertion(rougeNAssertion, {} as AtomicTestCase, output);
expect(result.pass).toBeFalsy();
expect(result.reason).toBe('ROUGE-N score 0.2 is less than threshold 0.75');
});
}); });
describe('assertionFromString', () => { describe('assertionFromString', () => {

View File

@@ -11,6 +11,6 @@
"sourceMap": true, "sourceMap": true,
"strict": true "strict": true
}, },
"include": ["src/"], "include": ["src/", "typings/**/*"],
"exclude": ["node_modules", "dist", "src/web/client/**/*"] "exclude": ["node_modules", "dist", "src/web/client/**/*"]
} }

10
typings/rouge.d.ts vendored Normal file
View File

@@ -0,0 +1,10 @@
declare module 'rouge' {
function n(
candidate: string,
reference: string | string[],
n?: number,
jackknife?: boolean,
): number;
function l(candidate: string, reference: string | string[], jackknife?: boolean): number;
function s(candidate: string, reference: string | string[], jackknife?: boolean): number;
}