|
|
|
|
Quelle transformers-dev.js
Sprache: JavaScript
|
|
import * as __WEBPACK_EXTERNAL_MODULE_chrome_global_content_ml_ort_webgpu_dev_mjs_a2210ba4__ from "chrome://global/content/ml/ort.webgpu-dev.mjs";
/******/ var __webpack_modules__ = ({
// Webpack module-map entry: "#onnxruntime-webgpu" re-exports the namespace imported at the top of
// this file from "chrome://global/content/ml/ort.webgpu-dev.mjs", so bundled code can require it by id.
/***/ "#onnxruntime-webgpu":
/*!****************************************************************!*\
!*** external "chrome://global/content/ml/ort.webgpu-dev.mjs" ***!
\****************************************************************/
/***/ ((module) => {
module.exports = __WEBPACK_EXTERNAL_MODULE_chrome_global_content_ml_ort_webgpu_dev_mjs_a2210ba4__;
/***/ }),
// Node-only dependencies (fs, path, sharp, url) are marked "(ignored)" by the bundler and are
// replaced by no-op module factories. fs and path each appear under two ids, presumably because
// they are required from different source modules. TODO confirm.
/***/ "?7a2c":
/*!********************!*\
!*** fs (ignored) ***!
\********************/
/***/ (() => {
/* (ignored) */
/***/ }),
/***/ "?a42a":
/*!**********************!*\
!*** path (ignored) ***!
\**********************/
/***/ (() => {
/* (ignored) */
/***/ }),
/***/ "?2b25":
/*!***********************!*\
!*** sharp (ignored) ***!
\***********************/
/***/ (() => {
/* (ignored) */
/***/ }),
/***/ "?569f":
/*!********************!*\
!*** fs (ignored) ***!
\********************/
/***/ (() => {
/* (ignored) */
/***/ }),
/***/ "?3f59":
/*!**********************!*\
!*** path (ignored) ***!
\**********************/
/***/ (() => {
/* (ignored) */
/***/ }),
/***/ "?154a":
/*!*********************!*\
!*** url (ignored) ***!
\*********************/
/***/ (() => {
/* (ignored) */
/***/ }),
/***/ "./node_modules/@huggingface/jinja/dist/index.js":
/*!*******************************************************!*\
!*** ./node_modules/@huggingface/jinja/dist/index.js ***!
\*******************************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
// Module bootstrap using webpack runtime helpers. `.r` presumably tags the exports object as an
// ES module. `.d` defines the public API as getters that return the bindings of this module
// (Environment, Interpreter, Template, parse, tokenize).
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ Environment: () => (/* binding */ Environment),
/* harmony export */ Interpreter: () => (/* binding */ Interpreter),
/* harmony export */ Template: () => (/* binding */ Template),
/* harmony export */ parse: () => (/* binding */ parse),
/* harmony export */ tokenize: () => (/* binding */ tokenize)
/* harmony export */ });
// src/lexer.ts
/**
 * Frozen table of the token type names produced by `tokenize`.
 * Each entry is followed by a comment line showing example source text for that token type.
 * `Not` reuses the value "UnaryOperator", so a `not` keyword and a unary `-`/`+` token share one type.
 * NOTE(review): `CallOperator` is never produced by `tokenize` in this file. The UnaryOperator
 * comment lists `!`, but a lone `!` is rejected by `tokenize`. Confirm both are intended.
 */
var TOKEN_TYPES = Object.freeze({
Text: "Text",
// The text between Jinja statements or expressions
NumericLiteral: "NumericLiteral",
// e.g., 123
BooleanLiteral: "BooleanLiteral",
// true or false
NullLiteral: "NullLiteral",
// none
StringLiteral: "StringLiteral",
// 'string'
Identifier: "Identifier",
// Variables, functions, etc.
Equals: "Equals",
// =
OpenParen: "OpenParen",
// (
CloseParen: "CloseParen",
// )
OpenStatement: "OpenStatement",
// {%
CloseStatement: "CloseStatement",
// %}
OpenExpression: "OpenExpression",
// {{
CloseExpression: "CloseExpression",
// }}
OpenSquareBracket: "OpenSquareBracket",
// [
CloseSquareBracket: "CloseSquareBracket",
// ]
OpenCurlyBracket: "OpenCurlyBracket",
// {
CloseCurlyBracket: "CloseCurlyBracket",
// }
Comma: "Comma",
// ,
Dot: "Dot",
// .
Colon: "Colon",
// :
Pipe: "Pipe",
// |
CallOperator: "CallOperator",
// ()
AdditiveBinaryOperator: "AdditiveBinaryOperator",
// + -
MultiplicativeBinaryOperator: "MultiplicativeBinaryOperator",
// * / %
ComparisonBinaryOperator: "ComparisonBinaryOperator",
// < > <= >= == !=
UnaryOperator: "UnaryOperator",
// ! - +
// Keywords
Set: "Set",
If: "If",
For: "For",
In: "In",
Is: "Is",
NotIn: "NotIn",
Else: "Else",
EndIf: "EndIf",
ElseIf: "ElseIf",
EndFor: "EndFor",
And: "And",
Or: "Or",
// Shares the UnaryOperator type (see header comment).
Not: "UnaryOperator",
Macro: "Macro",
EndMacro: "EndMacro"
});
/**
 * Reserved words and their token types. `tokenize` looks up each identifier-like word here
 * (via Object.hasOwn) and falls back to TOKEN_TYPES.Identifier.
 * The "not in" key can never match that lookup, because a scanned word never contains a space.
 * Instead, `tokenize` merges a Not-typed token followed by `in` into one NotIn token.
 */
var KEYWORDS = Object.freeze({
set: TOKEN_TYPES.Set,
for: TOKEN_TYPES.For,
in: TOKEN_TYPES.In,
is: TOKEN_TYPES.Is,
if: TOKEN_TYPES.If,
else: TOKEN_TYPES.Else,
endif: TOKEN_TYPES.EndIf,
elif: TOKEN_TYPES.ElseIf,
endfor: TOKEN_TYPES.EndFor,
and: TOKEN_TYPES.And,
or: TOKEN_TYPES.Or,
not: TOKEN_TYPES.Not,
"not in": TOKEN_TYPES.NotIn,
macro: TOKEN_TYPES.Macro,
endmacro: TOKEN_TYPES.EndMacro,
// Literals
true: TOKEN_TYPES.BooleanLiteral,
false: TOKEN_TYPES.BooleanLiteral,
none: TOKEN_TYPES.NullLiteral,
// NOTE: According to the Jinja docs: The special constants true, false, and none are indeed lowercase.
// Because that caused confusion in the past, (True used to expand to an undefined variable that was considered false),
// all three can now also be written in title case (True, False, and None). However, for consistency, (all Jinja identifiers are lowercase)
// you should use the lowercase versions.
True: TOKEN_TYPES.BooleanLiteral,
False: TOKEN_TYPES.BooleanLiteral,
None: TOKEN_TYPES.NullLiteral
});
/** A single lexical token emitted by `tokenize`. */
var Token = class {
  /**
   * @param {string} value Source text of the token (escape sequences already decoded for string literals).
   * @param {TokenType} type One of the TOKEN_TYPES values.
   */
  constructor(value, type) {
    Object.assign(this, { value, type });
  }
};
/**
 * True when `char` contains a RegExp `\w` character (ASCII letter, digit or underscore).
 * Callers pass single characters. Note that `undefined` is coerced to the string "undefined",
 * so it also yields true.
 */
function isWord(char) {
  const wordCharPattern = /\w/;
  return wordCharPattern.test(char);
}
/** True when `char` contains an ASCII decimal digit; callers pass single characters. */
function isInteger(char) {
  const digitPattern = /\d/;
  return digitPattern.test(char);
}
// Fixed-text tokens. `tokenize` tries them in this order, and the first entry whose text matches
// at the cursor wins. Two-character tokens must therefore precede their one-character prefixes:
// "{%"/"{{" before "{", "%}"/"}}" before "%"/"}", "<=" before "<", ">=" before ">", "==" before "=".
// Unary/negative-number handling of "-" and "+" happens in `tokenize` before this table is consulted.
var ORDERED_MAPPING_TABLE = [
// Control sequences
["{%", TOKEN_TYPES.OpenStatement],
["%}", TOKEN_TYPES.CloseStatement],
["{{", TOKEN_TYPES.OpenExpression],
["}}", TOKEN_TYPES.CloseExpression],
// Single character tokens
["(", TOKEN_TYPES.OpenParen],
[")", TOKEN_TYPES.CloseParen],
["{", TOKEN_TYPES.OpenCurlyBracket],
["}", TOKEN_TYPES.CloseCurlyBracket],
["[", TOKEN_TYPES.OpenSquareBracket],
["]", TOKEN_TYPES.CloseSquareBracket],
[",", TOKEN_TYPES.Comma],
[".", TOKEN_TYPES.Dot],
[":", TOKEN_TYPES.Colon],
["|", TOKEN_TYPES.Pipe],
// Comparison operators
["<=", TOKEN_TYPES.ComparisonBinaryOperator],
[">=", TOKEN_TYPES.ComparisonBinaryOperator],
["==", TOKEN_TYPES.ComparisonBinaryOperator],
["!=", TOKEN_TYPES.ComparisonBinaryOperator],
["<", TOKEN_TYPES.ComparisonBinaryOperator],
[">", TOKEN_TYPES.ComparisonBinaryOperator],
// Arithmetic operators
["+", TOKEN_TYPES.AdditiveBinaryOperator],
["-", TOKEN_TYPES.AdditiveBinaryOperator],
["*", TOKEN_TYPES.MultiplicativeBinaryOperator],
["/", TOKEN_TYPES.MultiplicativeBinaryOperator],
["%", TOKEN_TYPES.MultiplicativeBinaryOperator],
// Assignment operator
["=", TOKEN_TYPES.Equals]
];
/**
 * Maps the character that follows a backslash inside a string literal to the character it denotes.
 * `tokenize` (via its consumeWhile helper) consults this table and rejects any other escaped character.
 */
var ESCAPE_CHARACTERS = /* @__PURE__ */ new Map([
  ["n", "\n"], // New line
  ["t", "\t"], // Horizontal tab (must be an actual tab character, not a space)
  ["r", "\r"], // Carriage return
  ["b", "\b"], // Backspace
  ["f", "\f"], // Form feed
  ["v", "\v"], // Vertical tab
  ["'", "'"], // Single quote
  ['"', '"'], // Double quote
  ["\\", "\\"], // Backslash
]);
/**
 * Normalizes raw template text before lexing. The steps run in this order:
 *  1. drop a single trailing newline;
 *  2. collapse every `{# ... #}` comment to the placeholder `{##}`;
 *  3. with `lstrip_blocks`, remove spaces/tabs between a line start and `{%` or `{#`;
 *  4. with `trim_blocks`, remove the newline directly after `%}` or `#}`;
 *  5. delete the comment placeholders, then apply `-` whitespace control for `{%-`, `-%}`, `{{-` and `-}}`.
 * @param {string} template Raw template source.
 * @param {{lstrip_blocks?: boolean, trim_blocks?: boolean}} [options]
 * @returns {string} The normalized template text.
 */
function preprocess(template, options = {}) {
  let text = template.endsWith("\n") ? template.slice(0, -1) : template;
  text = text.replace(/{#.*?#}/gs, "{##}");
  if (options.lstrip_blocks) {
    text = text.replace(/^[ \t]*({[#%])/gm, "$1");
  }
  if (options.trim_blocks) {
    text = text.replace(/([#%]})\n/g, "$1");
  }
  // Final clean-up passes; their order matters (placeholders first, then whitespace control).
  const cleanupPasses = [
    [/{##}/g, ""],
    [/-%}\s*/g, "%}"],
    [/\s*{%-/g, "{%"],
    [/-}}\s*/g, "}}"],
    [/\s*{{-/g, "{{"],
  ];
  return cleanupPasses.reduce((acc, [pattern, replacement]) => acc.replace(pattern, replacement), text);
}
/**
 * Converts template source into a flat array of Tokens.
 *
 * The source is first normalized by `preprocess` (comments, whitespace control, trim/lstrip options).
 * Directly after the start of input or after `%}`/`}}`, the raw text up to the next `{%` or `{{`
 * becomes one Text token. Inside statements and expressions, whitespace is skipped and operators,
 * string/integer literals, identifiers and keywords are recognized.
 *
 * @param {string} source Template source.
 * @param {{lstrip_blocks?: boolean, trim_blocks?: boolean}} [options] Forwarded to `preprocess`.
 * @returns {Token[]}
 * @throws {SyntaxError} On an unrecognized character or invalid escape sequence, or when a run of
 *   whitespace/string/number/word characters reaches the end of the input.
 */
function tokenize(source, options = {}) {
const tokens = [];
const src = preprocess(source, options);
let cursorPosition = 0;
// Consumes characters while `predicate` accepts them and returns them as a string.
// A backslash is decoded through ESCAPE_CHARACTERS. Only the string-literal predicate accepts
// "\\", so in practice escapes are decoded only inside quoted strings.
// Throws "Unexpected end of input" as soon as consumption reaches the end of `src`.
const consumeWhile = (predicate) => {
let str = "";
while (predicate(src[cursorPosition])) {
if (src[cursorPosition] === "\\") {
++cursorPosition;
if (cursorPosition >= src.length)
throw new SyntaxError("Unexpected end of input");
const escaped = src[cursorPosition++];
const unescaped = ESCAPE_CHARACTERS.get(escaped);
if (unescaped === void 0) {
throw new SyntaxError(`Unexpected escaped character: ${escaped}`);
}
str += unescaped;
continue;
}
str += src[cursorPosition++];
if (cursorPosition >= src.length)
throw new SyntaxError("Unexpected end of input");
}
return str;
};
// Each iteration of the main loop emits exactly one token (or throws).
main:
while (cursorPosition < src.length) {
const lastTokenType = tokens.at(-1)?.type;
// Text mode: at the start of input or right after `%}` / `}}`.
if (lastTokenType === void 0 || lastTokenType === TOKEN_TYPES.CloseStatement || lastTokenType === TOKEN_TYPES.CloseExpression) {
let text = "";
while (cursorPosition < src.length && // Keep going until we hit the next Jinja statement or expression
!(src[cursorPosition] === "{" && (src[cursorPosition + 1] === "%" || src[cursorPosition + 1] === "{"))) {
text += src[cursorPosition++];
}
if (text.length > 0) {
tokens.push(new Token(text, TOKEN_TYPES.Text));
continue;
}
}
// Inside `{% %}` / `{{ }}`: skip whitespace, then classify the next character.
consumeWhile((char2) => /\s/.test(char2));
const char = src[cursorPosition];
// `-` / `+` directly after an operand-like token is a binary operator (handled by the table below).
// Anywhere else it is a sign: with digits following it becomes a signed NumericLiteral,
// otherwise a UnaryOperator token.
// NOTE(review): because TOKEN_TYPES.Not equals "UnaryOperator", a unary `-`/`+` followed by `in`
// is also merged into "not in" further below. Confirm this is unreachable in valid templates.
if (char === "-" || char === "+") {
const lastTokenType2 = tokens.at(-1)?.type;
if (lastTokenType2 === TOKEN_TYPES.Text || lastTokenType2 === void 0) {
throw new SyntaxError(`Unexpected character: ${char}`);
}
switch (lastTokenType2) {
case TOKEN_TYPES.Identifier:
case TOKEN_TYPES.NumericLiteral:
case TOKEN_TYPES.BooleanLiteral:
case TOKEN_TYPES.NullLiteral:
case TOKEN_TYPES.StringLiteral:
case TOKEN_TYPES.CloseParen:
case TOKEN_TYPES.CloseSquareBracket:
break;
default: {
++cursorPosition;
const num = consumeWhile(isInteger);
tokens.push(
new Token(`${char}${num}`, num.length > 0 ? TOKEN_TYPES.NumericLiteral : TOKEN_TYPES.UnaryOperator)
);
continue;
}
}
}
// Fixed-text tokens; the table order ensures the longest match wins.
for (const [char2, token] of ORDERED_MAPPING_TABLE) {
const slice2 = src.slice(cursorPosition, cursorPosition + char2.length);
if (slice2 === char2) {
tokens.push(new Token(char2, token));
cursorPosition += char2.length;
continue main;
}
}
// Quoted string: the value excludes the quotes and has escapes decoded.
if (char === "'" || char === '"') {
++cursorPosition;
const str = consumeWhile((c) => c !== char);
tokens.push(new Token(str, TOKEN_TYPES.StringLiteral));
++cursorPosition;
continue;
}
// Only unsigned integer digits are scanned here; there is no decimal point support in this lexer.
if (isInteger(char)) {
const num = consumeWhile(isInteger);
tokens.push(new Token(num, TOKEN_TYPES.NumericLiteral));
continue;
}
// Words become keywords or identifiers; `not` followed by `in` is merged into one NotIn token.
if (isWord(char)) {
const word = consumeWhile(isWord);
const type = Object.hasOwn(KEYWORDS, word) ? KEYWORDS[word] : TOKEN_TYPES.Identifier;
if (type === TOKEN_TYPES.In && tokens.at(-1)?.type === TOKEN_TYPES.Not) {
tokens.pop();
tokens.push(new Token("not in", TOKEN_TYPES.NotIn));
} else {
tokens.push(new Token(word, type));
}
continue;
}
throw new SyntaxError(`Unexpected character: ${char}`);
}
return tokens;
}
// src/ast.ts
/**
 * Base class of every AST node. Each subclass overrides `type` with its own node name, which is
 * how node kinds are told apart (e.g. the parser checks `type === "Identifier"`).
 */
var Statement = class {
  type = "Statement";
};
/** Root of a parsed template: the top-level nodes in source order. */
var Program = class extends Statement {
  type = "Program";
  /** @param {Statement[]} body */
  constructor(body) {
    super();
    Object.assign(this, { body });
  }
};
/**
 * `{% if %}` statement or inline `a if b else c` expression.
 * `alternate` holds the `else` branch, or a single nested If for an `elif` chain.
 */
var If = class extends Statement {
  type = "If";
  constructor(test, body, alternate) {
    super();
    Object.assign(this, { test, body, alternate });
  }
};
/** `{% for loopvar in iterable %}` loop; `defaultBlock` is the body of the optional `{% else %}`. */
var For = class extends Statement {
  type = "For";
  constructor(loopvar, iterable, body, defaultBlock) {
    super();
    Object.assign(this, { loopvar, iterable, body, defaultBlock });
  }
};
/** `{% set assignee = value %}`. */
var SetStatement = class extends Statement {
  type = "Set";
  constructor(assignee, value) {
    super();
    Object.assign(this, { assignee, value });
  }
};
/** `{% macro name(args) %} ... {% endmacro %}` definition. */
var Macro = class extends Statement {
  type = "Macro";
  constructor(name, args, body) {
    super();
    Object.assign(this, { name, args, body });
  }
};
/** Base class of expression nodes. */
var Expression = class extends Statement {
  type = "Expression";
};
/** Member access: `object.property` (computed = false) or `object[property]` (computed = true). */
var MemberExpression = class extends Expression {
  type = "MemberExpression";
  constructor(object, property, computed) {
    super();
    Object.assign(this, { object, property, computed });
  }
};
/** Call `callee(args...)`; `args` may contain KeywordArgumentExpression nodes. */
var CallExpression = class extends Expression {
  type = "CallExpression";
  constructor(callee, args) {
    super();
    Object.assign(this, { callee, args });
  }
};
/** A name reference. */
var Identifier = class extends Expression {
  type = "Identifier";
  /**
   * @param {string} value The name of the identifier
   */
  constructor(value) {
    super();
    Object.assign(this, { value });
  }
};
/** Base class of literal nodes; `value` holds the literal payload. */
var Literal = class extends Expression {
  type = "Literal";
  constructor(value) {
    super();
    Object.assign(this, { value });
  }
};
/** Number literal; the parser stores a JavaScript number in `value`. */
var NumericLiteral = class extends Literal {
  type = "NumericLiteral";
};
/** String literal; the parser also uses it for raw template text. */
var StringLiteral = class extends Literal {
  type = "StringLiteral";
};
/** `true` / `false`; `value` is a JavaScript boolean. */
var BooleanLiteral = class extends Literal {
  type = "BooleanLiteral";
};
/** `none`; `value` is null. */
var NullLiteral = class extends Literal {
  type = "NullLiteral";
};
/** `[a, b, ...]`; `value` is an array of expression nodes. */
var ArrayLiteral = class extends Literal {
  type = "ArrayLiteral";
};
/** `(a, b, ...)` or a bare comma sequence; `value` is an array of expression nodes. */
var TupleLiteral = class extends Literal {
  type = "TupleLiteral";
};
/** `{k: v, ...}`; `value` is a Map from key expression node to value expression node. */
var ObjectLiteral = class extends Literal {
  type = "ObjectLiteral";
};
/** Binary operation; `operator` is the operator Token (e.g. `+`, `==`, `and`, `in`). */
var BinaryExpression = class extends Expression {
  type = "BinaryExpression";
  constructor(operator, left, right) {
    super();
    Object.assign(this, { operator, left, right });
  }
};
/** `operand | filter`; `filter` is an Identifier, or a CallExpression when the filter takes arguments. */
var FilterExpression = class extends Expression {
  type = "FilterExpression";
  constructor(operand, filter) {
    super();
    Object.assign(this, { operand, filter });
  }
};
/** `iterable if test` without an `else` branch. */
var SelectExpression = class extends Expression {
  type = "SelectExpression";
  constructor(iterable, test) {
    super();
    Object.assign(this, { iterable, test });
  }
};
/** `operand is [not] test`; `negate` is true for `is not`. */
var TestExpression = class extends Expression {
  type = "TestExpression";
  constructor(operand, negate, test) {
    super();
    Object.assign(this, { operand, negate, test });
  }
};
/** Unary operation such as `not x`; `operator` is the operator Token. */
var UnaryExpression = class extends Expression {
  type = "UnaryExpression";
  constructor(operator, argument) {
    super();
    Object.assign(this, { operator, argument });
  }
};
/** Slice `[start:stop:step]`; omitted parts are undefined. */
var SliceExpression = class extends Expression {
  type = "SliceExpression";
  constructor(start = undefined, stop = undefined, step = undefined) {
    super();
    Object.assign(this, { start, stop, step });
  }
};
/** Keyword argument `key=value` inside a call argument list. */
var KeywordArgumentExpression = class extends Expression {
  type = "KeywordArgumentExpression";
  constructor(key, value) {
    super();
    Object.assign(this, { key, value });
  }
};
// src/parser.ts
/**
 * Recursive-descent parser that turns the tokens produced by `tokenize` into a Program AST.
 *
 * Expression precedence, lowest to highest:
 *   inline `a if b else c` / select `a if b` -> `or` -> `and` -> `not` -> comparisons, `in`, `not in`
 *   -> `+ -` -> `* / %` -> tests (`is [not] name`) -> filters (`| name(...)`) -> calls / member access
 *   -> primary expressions (literals, identifiers, parenthesized/array/object literals).
 *
 * @param {Token[]} tokens Tokens to parse; the array itself is not modified.
 * @returns {Program} Root node whose `body` holds the top-level nodes in source order.
 * @throws {SyntaxError} On an unexpected token, or when the input ends in the middle of a construct.
 * @throws {Error} ("Parser Error: ...") When a specific expected token is missing or the input ends instead.
 */
function parse(tokens) {
  const program = new Program([]);
  // Index of the next unconsumed token; shared and advanced by all nested parse functions.
  let current = 0;
  /**
   * Returns the token at the cursor without consuming it.
   * @throws {SyntaxError} When the input is exhausted, instead of a TypeError on `undefined.type`.
   */
  function peek() {
    const token = tokens[current];
    if (token === void 0) {
      throw new SyntaxError("Unexpected end of input");
    }
    return token;
  }
  /**
   * Consumes the next token and checks its type.
   * @param {string} type Required token type.
   * @param {string} error Description used in the error message.
   * @returns {Token} The consumed token.
   * @throws {Error} When the token has a different type or no token is left.
   */
  function expect(type, error) {
    const prev = tokens[current++];
    if (!prev || prev.type !== type) {
      // `prev` is undefined at end of input: report that instead of crashing on `prev.type`.
      throw new Error(`Parser Error: ${error}. ${prev?.type ?? "end of input"} !== ${type}.`);
    }
    return prev;
  }
  /** Parses one top-level element: raw text, a `{% ... %}` statement or a `{{ ... }}` expression. */
  function parseAny() {
    const { type } = peek();
    switch (type) {
      case TOKEN_TYPES.Text:
        return parseText();
      case TOKEN_TYPES.OpenStatement:
        return parseJinjaStatement();
      case TOKEN_TYPES.OpenExpression:
        return parseJinjaExpression();
      default:
        throw new SyntaxError(`Unexpected token type: ${type}`);
    }
  }
  /**
   * True when enough tokens remain and they do NOT match `types` position by position.
   * Returns false when fewer than `types.length` tokens remain, which ends the body loops below.
   */
  function not(...types) {
    return current + types.length <= tokens.length && types.some((type, i) => type !== tokens[current + i].type);
  }
  /** True when the upcoming tokens match `types` position by position (false if too few remain). */
  function is(...types) {
    return current + types.length <= tokens.length && types.every((type, i) => type === tokens[current + i].type);
  }
  /** Raw template text becomes a StringLiteral node. */
  function parseText() {
    return new StringLiteral(expect(TOKEN_TYPES.Text, "Expected text token").value);
  }
  /** Parses a whole `{% set %}`, `{% if %}`, `{% macro %}` or `{% for %}` statement, including its end tag. */
  function parseJinjaStatement() {
    expect(TOKEN_TYPES.OpenStatement, "Expected opening statement token");
    let result;
    const { type } = peek();
    switch (type) {
      case TOKEN_TYPES.Set:
        ++current;
        result = parseSetStatement();
        expect(TOKEN_TYPES.CloseStatement, "Expected closing statement token");
        break;
      case TOKEN_TYPES.If:
        ++current;
        result = parseIfStatement();
        expect(TOKEN_TYPES.OpenStatement, "Expected {% token");
        expect(TOKEN_TYPES.EndIf, "Expected endif token");
        expect(TOKEN_TYPES.CloseStatement, "Expected %} token");
        break;
      case TOKEN_TYPES.Macro:
        ++current;
        result = parseMacroStatement();
        expect(TOKEN_TYPES.OpenStatement, "Expected {% token");
        expect(TOKEN_TYPES.EndMacro, "Expected endmacro token");
        expect(TOKEN_TYPES.CloseStatement, "Expected %} token");
        break;
      case TOKEN_TYPES.For:
        ++current;
        result = parseForStatement();
        expect(TOKEN_TYPES.OpenStatement, "Expected {% token");
        expect(TOKEN_TYPES.EndFor, "Expected endfor token");
        expect(TOKEN_TYPES.CloseStatement, "Expected %} token");
        break;
      default:
        throw new SyntaxError(`Unknown statement type: ${type}`);
    }
    return result;
  }
  /** Parses `{{ expression }}`. */
  function parseJinjaExpression() {
    expect(TOKEN_TYPES.OpenExpression, "Expected opening expression token");
    const result = parseExpression();
    expect(TOKEN_TYPES.CloseExpression, "Expected closing expression token");
    return result;
  }
  /**
   * Parses what follows `set`: `target = value`, where `value` may itself be another assignment.
   * Without `=`, the bare expression is returned.
   */
  function parseSetStatement() {
    const left = parseExpression();
    if (is(TOKEN_TYPES.Equals)) {
      ++current;
      const value = parseSetStatement();
      return new SetStatement(left, value);
    }
    return left;
  }
  /**
   * Parses the rest of `{% if test %}` / `{% elif test %}`: the condition, the body up to the next
   * `elif`/`else`/`endif`, and the alternate branch. An `elif` becomes a nested If inside `alternate`.
   * The closing `{% endif %}` is left for the caller to consume.
   */
  function parseIfStatement() {
    const test = parseExpression();
    expect(TOKEN_TYPES.CloseStatement, "Expected closing statement token");
    const body = [];
    const alternate = [];
    while (!(tokens[current]?.type === TOKEN_TYPES.OpenStatement && (tokens[current + 1]?.type === TOKEN_TYPES.ElseIf || tokens[current + 1]?.type === TOKEN_TYPES.Else || tokens[current + 1]?.type === TOKEN_TYPES.EndIf))) {
      body.push(parseAny());
    }
    if (tokens[current]?.type === TOKEN_TYPES.OpenStatement && tokens[current + 1]?.type !== TOKEN_TYPES.EndIf) {
      ++current;
      if (is(TOKEN_TYPES.ElseIf)) {
        expect(TOKEN_TYPES.ElseIf, "Expected elseif token");
        alternate.push(parseIfStatement());
      } else {
        expect(TOKEN_TYPES.Else, "Expected else token");
        expect(TOKEN_TYPES.CloseStatement, "Expected closing statement token");
        while (!(tokens[current]?.type === TOKEN_TYPES.OpenStatement && tokens[current + 1]?.type === TOKEN_TYPES.EndIf)) {
          alternate.push(parseAny());
        }
      }
    }
    return new If(test, body, alternate);
  }
  /** Parses `name(args) %}` and the macro body up to `{% endmacro` (which the caller consumes). */
  function parseMacroStatement() {
    const name = parsePrimaryExpression();
    if (name.type !== "Identifier") {
      throw new SyntaxError(`Expected identifier following macro statement`);
    }
    const args = parseArgs();
    expect(TOKEN_TYPES.CloseStatement, "Expected closing statement token");
    const body = [];
    while (not(TOKEN_TYPES.OpenStatement, TOKEN_TYPES.EndMacro)) {
      body.push(parseAny());
    }
    return new Macro(name, args, body);
  }
  /**
   * Parses one expression, or a comma-separated sequence of them as a TupleLiteral.
   * With `primary` set, each element is a primary expression (used for `for` loop targets).
   */
  function parseExpressionSequence(primary = false) {
    const fn = primary ? parsePrimaryExpression : parseExpression;
    const expressions = [fn()];
    const isTuple = is(TOKEN_TYPES.Comma);
    while (isTuple) {
      ++current;
      expressions.push(fn());
      if (!is(TOKEN_TYPES.Comma)) {
        break;
      }
    }
    return isTuple ? new TupleLiteral(expressions) : expressions[0];
  }
  /**
   * Parses `loopvar in iterable %}`, the body, and an optional `{% else %}` block stored as the
   * For node's `defaultBlock`. The closing `{% endfor` is left for the caller.
   */
  function parseForStatement() {
    const loopVariable = parseExpressionSequence(true);
    if (!(loopVariable instanceof Identifier || loopVariable instanceof TupleLiteral)) {
      throw new SyntaxError(`Expected identifier/tuple for the loop variable, got ${loopVariable.type} instead`);
    }
    expect(TOKEN_TYPES.In, "Expected `in` keyword following loop variable");
    const iterable = parseExpression();
    expect(TOKEN_TYPES.CloseStatement, "Expected closing statement token");
    const body = [];
    while (not(TOKEN_TYPES.OpenStatement, TOKEN_TYPES.EndFor) && not(TOKEN_TYPES.OpenStatement, TOKEN_TYPES.Else)) {
      body.push(parseAny());
    }
    const alternative = [];
    if (is(TOKEN_TYPES.OpenStatement, TOKEN_TYPES.Else)) {
      // Skip the `{%` and `else` tokens.
      ++current;
      ++current;
      expect(TOKEN_TYPES.CloseStatement, "Expected closing statement token");
      while (not(TOKEN_TYPES.OpenStatement, TOKEN_TYPES.EndFor)) {
        alternative.push(parseAny());
      }
    }
    return new For(loopVariable, iterable, body, alternative);
  }
  /** Entry point for a full expression (lowest precedence level). */
  function parseExpression() {
    return parseIfExpression();
  }
  /** `a if cond else b` becomes an If node; `a if cond` without `else` becomes a SelectExpression. */
  function parseIfExpression() {
    const a = parseLogicalOrExpression();
    if (is(TOKEN_TYPES.If)) {
      ++current;
      const predicate = parseLogicalOrExpression();
      if (is(TOKEN_TYPES.Else)) {
        ++current;
        const b = parseLogicalOrExpression();
        return new If(predicate, [a], [b]);
      } else {
        return new SelectExpression(a, predicate);
      }
    }
    return a;
  }
  function parseLogicalOrExpression() {
    let left = parseLogicalAndExpression();
    while (is(TOKEN_TYPES.Or)) {
      const operator = tokens[current];
      ++current;
      const right = parseLogicalAndExpression();
      left = new BinaryExpression(operator, left, right);
    }
    return left;
  }
  function parseLogicalAndExpression() {
    let left = parseLogicalNegationExpression();
    while (is(TOKEN_TYPES.And)) {
      const operator = tokens[current];
      ++current;
      const right = parseLogicalNegationExpression();
      left = new BinaryExpression(operator, left, right);
    }
    return left;
  }
  /** `not x` (recursive, so `not not x` nests); falls through to comparisons otherwise. */
  function parseLogicalNegationExpression() {
    let right;
    while (is(TOKEN_TYPES.Not)) {
      const operator = tokens[current];
      ++current;
      const arg = parseLogicalNegationExpression();
      right = new UnaryExpression(operator, arg);
    }
    return right ?? parseComparisonExpression();
  }
  function parseComparisonExpression() {
    let left = parseAdditiveExpression();
    while (is(TOKEN_TYPES.ComparisonBinaryOperator) || is(TOKEN_TYPES.In) || is(TOKEN_TYPES.NotIn)) {
      const operator = tokens[current];
      ++current;
      const right = parseAdditiveExpression();
      left = new BinaryExpression(operator, left, right);
    }
    return left;
  }
  function parseAdditiveExpression() {
    let left = parseMultiplicativeExpression();
    while (is(TOKEN_TYPES.AdditiveBinaryOperator)) {
      const operator = tokens[current];
      ++current;
      const right = parseMultiplicativeExpression();
      left = new BinaryExpression(operator, left, right);
    }
    return left;
  }
  function parseCallMemberExpression() {
    const member = parseMemberExpression();
    if (is(TOKEN_TYPES.OpenParen)) {
      return parseCallExpression(member);
    }
    return member;
  }
  /** Parses `callee(args)`, including chained calls such as `f(a)(b)`. */
  function parseCallExpression(callee) {
    let callExpression = new CallExpression(callee, parseArgs());
    if (is(TOKEN_TYPES.OpenParen)) {
      callExpression = parseCallExpression(callExpression);
    }
    return callExpression;
  }
  function parseArgs() {
    expect(TOKEN_TYPES.OpenParen, "Expected opening parenthesis for arguments list");
    const args = parseArgumentsList();
    expect(TOKEN_TYPES.CloseParen, "Expected closing parenthesis for arguments list");
    return args;
  }
  /**
   * Parses positional and `name=value` keyword arguments up to (not including) `)`.
   * Commas between arguments are consumed when present but are not required.
   */
  function parseArgumentsList() {
    const args = [];
    while (!is(TOKEN_TYPES.CloseParen)) {
      let argument = parseExpression();
      if (is(TOKEN_TYPES.Equals)) {
        ++current;
        if (!(argument instanceof Identifier)) {
          throw new SyntaxError(`Expected identifier for keyword argument`);
        }
        const value = parseExpression();
        argument = new KeywordArgumentExpression(argument, value);
      }
      args.push(argument);
      if (is(TOKEN_TYPES.Comma)) {
        ++current;
      }
    }
    return args;
  }
  /** Parses the inside of `[...]`: a single index expression, or a `start:stop:step` slice with any part omitted. */
  function parseMemberExpressionArgumentsList() {
    const slices = [];
    let isSlice = false;
    while (!is(TOKEN_TYPES.CloseSquareBracket)) {
      if (is(TOKEN_TYPES.Colon)) {
        slices.push(void 0);
        ++current;
        isSlice = true;
      } else {
        slices.push(parseExpression());
        if (is(TOKEN_TYPES.Colon)) {
          ++current;
          isSlice = true;
        }
      }
    }
    if (slices.length === 0) {
      throw new SyntaxError(`Expected at least one argument for member/slice expression`);
    }
    if (isSlice) {
      if (slices.length > 3) {
        throw new SyntaxError(`Expected 0-3 arguments for slice expression`);
      }
      return new SliceExpression(...slices);
    }
    return slices[0];
  }
  /** Parses chains of `.name` and `[index]` accesses after a primary expression. */
  function parseMemberExpression() {
    let object = parsePrimaryExpression();
    while (is(TOKEN_TYPES.Dot) || is(TOKEN_TYPES.OpenSquareBracket)) {
      const operator = tokens[current];
      ++current;
      let property;
      const computed = operator.type !== TOKEN_TYPES.Dot;
      if (computed) {
        property = parseMemberExpressionArgumentsList();
        expect(TOKEN_TYPES.CloseSquareBracket, "Expected closing square bracket");
      } else {
        property = parsePrimaryExpression();
        if (property.type !== "Identifier") {
          throw new SyntaxError(`Expected identifier following dot operator`);
        }
      }
      object = new MemberExpression(object, property, computed);
    }
    return object;
  }
  function parseMultiplicativeExpression() {
    let left = parseTestExpression();
    while (is(TOKEN_TYPES.MultiplicativeBinaryOperator)) {
      const operator = tokens[current];
      ++current;
      const right = parseTestExpression();
      left = new BinaryExpression(operator, left, right);
    }
    return left;
  }
  /** `operand is [not] name`; the literals true/false/none are accepted as test names. */
  function parseTestExpression() {
    let operand = parseFilterExpression();
    while (is(TOKEN_TYPES.Is)) {
      ++current;
      const negate = is(TOKEN_TYPES.Not);
      if (negate) {
        ++current;
      }
      let filter = parsePrimaryExpression();
      if (filter instanceof BooleanLiteral) {
        filter = new Identifier(filter.value.toString());
      } else if (filter instanceof NullLiteral) {
        filter = new Identifier("none");
      }
      if (!(filter instanceof Identifier)) {
        throw new SyntaxError(`Expected identifier for the test`);
      }
      operand = new TestExpression(operand, negate, filter);
    }
    return operand;
  }
  /** `operand | name` or `operand | name(args)`, possibly chained. */
  function parseFilterExpression() {
    let operand = parseCallMemberExpression();
    while (is(TOKEN_TYPES.Pipe)) {
      ++current;
      let filter = parsePrimaryExpression();
      if (!(filter instanceof Identifier)) {
        throw new SyntaxError(`Expected identifier for the filter`);
      }
      if (is(TOKEN_TYPES.OpenParen)) {
        filter = parseCallExpression(filter);
      }
      operand = new FilterExpression(operand, filter);
    }
    return operand;
  }
  /** Literals, identifiers, parenthesized expressions/tuples, array literals and object literals. */
  function parsePrimaryExpression() {
    const token = peek();
    switch (token.type) {
      case TOKEN_TYPES.NumericLiteral:
        ++current;
        return new NumericLiteral(Number(token.value));
      case TOKEN_TYPES.StringLiteral:
        ++current;
        return new StringLiteral(token.value);
      case TOKEN_TYPES.BooleanLiteral:
        ++current;
        return new BooleanLiteral(token.value.toLowerCase() === "true");
      case TOKEN_TYPES.NullLiteral:
        ++current;
        return new NullLiteral(null);
      case TOKEN_TYPES.Identifier:
        ++current;
        return new Identifier(token.value);
      case TOKEN_TYPES.OpenParen: {
        ++current;
        const expression = parseExpressionSequence();
        const closing = peek();
        if (closing.type !== TOKEN_TYPES.CloseParen) {
          throw new SyntaxError(`Expected closing parenthesis, got ${closing.type} instead`);
        }
        ++current;
        return expression;
      }
      case TOKEN_TYPES.OpenSquareBracket: {
        ++current;
        const values = [];
        while (!is(TOKEN_TYPES.CloseSquareBracket)) {
          values.push(parseExpression());
          if (is(TOKEN_TYPES.Comma)) {
            ++current;
          }
        }
        ++current;
        return new ArrayLiteral(values);
      }
      case TOKEN_TYPES.OpenCurlyBracket: {
        ++current;
        const values = /* @__PURE__ */ new Map();
        while (!is(TOKEN_TYPES.CloseCurlyBracket)) {
          const key = parseExpression();
          expect(TOKEN_TYPES.Colon, "Expected colon between key and value in object literal");
          const value = parseExpression();
          values.set(key, value);
          if (is(TOKEN_TYPES.Comma)) {
            ++current;
          }
        }
        ++current;
        return new ObjectLiteral(values);
      }
      default:
        throw new SyntaxError(`Unexpected token: ${token.type}`);
    }
  }
  while (current < tokens.length) {
    program.body.push(parseAny());
  }
  return program;
}
// src/utils.ts
/**
 * Python-style `range`: numbers from `start` (inclusive) towards `stop` (exclusive) in increments of `step`.
 * With a single argument, that argument is the stop value and the range starts at 0.
 * Negative steps count downwards as in Python, e.g. `range(5, 0, -1)` -> [5, 4, 3, 2, 1].
 * @param {number} start Start value, or the stop value when `stop` is omitted.
 * @param {number} [stop] Exclusive end value.
 * @param {number} [step=1] Non-zero increment.
 * @returns {number[]}
 * @throws {Error} If `step` is 0, which previously caused an endless loop.
 */
function range(start, stop, step = 1) {
  const [from, to] = stop === void 0 ? [0, start] : [start, stop];
  if (step === 0) {
    throw new Error("range() arg 3 must not be zero");
  }
  const result = [];
  if (step > 0) {
    for (let i = from; i < to; i += step) {
      result.push(i);
    }
  } else {
    for (let i = from; i > to; i += step) {
      result.push(i);
    }
  }
  return result;
}
/**
 * Python-style slicing `array[start:stop:step]`.
 * Negative indices count from the end and out-of-range indices are clamped. `undefined`/`null`
 * bounds take Python's defaults for the direction of `step`.
 * @param {ArrayLike<*>} array Sequence to slice (read via `length` and integer indexing).
 * @param {number} [start]
 * @param {number} [stop]
 * @param {number} [step=1] Increment; a step of 0 yields an empty result (Python would raise).
 * @returns {Array<*>} A new array holding the selected elements.
 */
function slice(array, start, stop, step = 1) {
  const length = array.length;
  const direction = Math.sign(step);
  let from;
  let to;
  if (direction >= 0) {
    from = start == null ? 0 : start < 0 ? Math.max(length + start, 0) : Math.min(start, length);
    to = stop == null ? length : stop < 0 ? Math.max(length + stop, 0) : Math.min(stop, length);
  } else {
    // When walking backwards, -1 is the internal "before index 0" sentinel. Only an omitted stop may
    // use it directly; an explicit negative stop (including -1) counts from the end, as in Python.
    from = start == null ? length - 1 : start < 0 ? Math.max(length + start, -1) : Math.min(start, length - 1);
    to = stop == null ? -1 : stop < 0 ? Math.max(length + stop, -1) : Math.min(stop, length - 1);
  }
  const result = [];
  for (let i = from; direction * i < direction * to; i += step) {
    result.push(array[i]);
  }
  return result;
}
/**
 * Jinja-compatible `title` filter: upper-cases the first character of every word and lower-cases
 * the remaining characters. As in Jinja, words are runs of characters delimited by whitespace,
 * `-`, `(`, `{`, `[` or `<`. For example, "HELLO world" -> "Hello World" and "o'neil" -> "O'neil".
 * @param {string} value
 * @returns {string}
 */
function titleCase(value) {
  return value.replace(/[^\-\s({\[<]+/g, (word) => word.charAt(0).toUpperCase() + word.slice(1).toLowerCase());
}
// src/runtime.ts
/**
 * Base class for every value the interpreter works with. Subclasses override `type`, may expose
 * callable members through `builtins`, and may override `__bool__` to mimic Python truthiness.
 */
var RuntimeValue = class {
  type = "RuntimeValue";
  value;
  /** Built-in members exposed by this value type, keyed by member name (empty by default). */
  builtins = /* @__PURE__ */ new Map();
  /**
   * @param {*} [value] The wrapped JavaScript value.
   */
  constructor(value = undefined) {
    this.value = value;
  }
  /**
   * Truthiness of the wrapped value using JavaScript rules; subclasses override where Python differs.
   * @returns {BooleanValue} A BooleanValue wrapping true or false.
   */
  __bool__() {
    const truthy = Boolean(this.value);
    return new BooleanValue(truthy);
  }
};
/** A number (JavaScript `number`). */
var NumericValue = class extends RuntimeValue {
  type = "NumericValue";
};
/**
 * A string with Python-style string methods as builtins. `length` is captured when the value is
 * constructed; the methods read `this.value` each time they are called.
 */
var StringValue = class extends RuntimeValue {
  type = "StringValue";
  builtins = (() => {
    // Wraps a string -> string transform as a builtin that returns a new StringValue.
    const method = (transform) => new FunctionValue(() => new StringValue(transform(this.value)));
    return /* @__PURE__ */ new Map([
      ["upper", method((text) => text.toUpperCase())],
      ["lower", method((text) => text.toLowerCase())],
      ["strip", method((text) => text.trim())],
      ["title", method((text) => titleCase(text))],
      ["length", new NumericValue(this.value.length)],
      ["rstrip", method((text) => text.trimEnd())],
      ["lstrip", method((text) => text.trimStart())],
    ]);
  })();
};
/** A boolean (JavaScript `boolean`). */
var BooleanValue = class extends RuntimeValue {
  type = "BooleanValue";
};
/** A mapping (Python dict) whose `value` is a JavaScript Map keyed by strings. */
var ObjectValue = class extends RuntimeValue {
  type = "ObjectValue";
  /**
   * Python truthiness: a mapping is truthy only when it is non-empty. JavaScript treats every
   * object as truthy (`{} && 5` is 5), whereas in Python `{} and 5` is `{}`.
   */
  __bool__() {
    const nonEmpty = this.value.size > 0;
    return new BooleanValue(nonEmpty);
  }
  builtins = /* @__PURE__ */ new Map([
    [
      "get",
      // dict.get(key[, default]): the stored value, else the default, else none.
      new FunctionValue((args) => {
        const [key, fallback] = args;
        if (!(key instanceof StringValue)) {
          throw new Error(`Object key must be a string: got ${key.type}`);
        }
        const found = this.value.get(key.value);
        return found ?? fallback ?? new NullValue();
      })
    ],
    [
      "items",
      // dict.items(): a list of [key, value] pairs in insertion order.
      new FunctionValue(() => {
        const pairs = Array.from(this.value.entries(), ([name, item]) => new ArrayValue([new StringValue(name), item]));
        return new ArrayValue(pairs);
      })
    ]
  ]);
};
/** Mapping variant that carries keyword arguments; behaves exactly like ObjectValue. */
var KeywordArgumentsValue = class extends ObjectValue {
  type = "KeywordArgumentsValue";
};
/** A list whose `value` is a JavaScript array; exposes `length` (captured at construction). */
var ArrayValue = class extends RuntimeValue {
  type = "ArrayValue";
  builtins = /* @__PURE__ */ new Map().set("length", new NumericValue(this.value.length));
  /**
   * Python truthiness: a list is truthy only when it is non-empty. JavaScript treats every array
   * as truthy (`[] && 5` is 5), whereas in Python `[] and 5` is `[]`.
   */
  __bool__() {
    const nonEmpty = this.value.length > 0;
    return new BooleanValue(nonEmpty);
  }
};
/** Tuple flavour of ArrayValue; identical behavior apart from `type`. */
var TupleValue = class extends ArrayValue {
  type = "TupleValue";
};
/** Wraps a JavaScript function that the interpreter can call (presumably with an array of runtime values). */
var FunctionValue = class extends RuntimeValue {
  type = "FunctionValue";
};
/** Python `None`. */
var NullValue = class extends RuntimeValue {
  type = "NullValue";
};
/** Result of looking up a name that is not defined. */
var UndefinedValue = class extends RuntimeValue {
  type = "UndefinedValue";
};
/**
 * A variable scope. Name lookups fall back to the `parent` environment, while writes go to this
 * scope. Each environment also carries the table of Jinja tests (`x is odd`, `x is defined`, ...).
 */
var Environment = class {
  /**
   * @param {Environment} [parent] Enclosing scope consulted by `resolve` and `lookupVariable`
   *   when a name is not declared in this scope.
   */
  constructor(parent) {
    this.parent = parent;
  }
  /**
   * The variables declared in this environment. Every environment starts with a `namespace`
   * builtin. With no arguments it returns a new empty ObjectValue; with a single ObjectValue it
   * returns that object unchanged.
   */
  variables = /* @__PURE__ */ new Map([
    [
      "namespace",
      new FunctionValue((args) => {
        if (args.length === 0) {
          return new ObjectValue(/* @__PURE__ */ new Map());
        }
        if (args.length !== 1 || !(args[0] instanceof ObjectValue)) {
          throw new Error("`namespace` expects either zero arguments or a single object argument");
        }
        return args[0];
      })
    ]
  ]);
  /**
   * The tests available in this environment (`operand is <name>`), keyed by test name.
   * Each receives the operand (plus a second value for `equalto`/`eq`) and returns a truthy/falsy result.
   */
  tests = /* @__PURE__ */ new Map([
    ["boolean", (operand) => operand.type === "BooleanValue"],
    ["callable", (operand) => operand instanceof FunctionValue],
    [
      "odd",
      (operand) => {
        if (operand.type !== "NumericValue") {
          throw new Error(`Cannot apply test "odd" to type: ${operand.type}`);
        }
        return operand.value % 2 !== 0;
      }
    ],
    [
      "even",
      (operand) => {
        if (operand.type !== "NumericValue") {
          throw new Error(`Cannot apply test "even" to type: ${operand.type}`);
        }
        return operand.value % 2 === 0;
      }
    ],
    ["false", (operand) => operand.type === "BooleanValue" && !operand.value],
    ["true", (operand) => operand.type === "BooleanValue" && operand.value],
    ["none", (operand) => operand.type === "NullValue"],
    ["string", (operand) => operand.type === "StringValue"],
    ["number", (operand) => operand.type === "NumericValue"],
    ["integer", (operand) => operand.type === "NumericValue" && Number.isInteger(operand.value)],
    // instanceof (not a `type` string comparison) so that TupleValue, a subclass of ArrayValue, counts too.
    ["iterable", (operand) => operand instanceof ArrayValue || operand.type === "StringValue"],
    ["mapping", (operand) => operand.type === "ObjectValue"],
    [
      "lower",
      (operand) => {
        const str = operand.value;
        return operand.type === "StringValue" && str === str.toLowerCase();
      }
    ],
    [
      "upper",
      (operand) => {
        const str = operand.value;
        return operand.type === "StringValue" && str === str.toUpperCase();
      }
    ],
    ["defined", (operand) => operand.type !== "UndefinedValue"],
    ["undefined", (operand) => operand.type === "UndefinedValue"],
    ["equalto", (a, b) => a.value === b.value],
    ["eq", (a, b) => a.value === b.value]
  ]);
  /**
   * Declare a variable in the current environment from a plain JavaScript value
   * (converted with convertToRuntimeValues).
   * @throws {SyntaxError} If `name` is already declared in this scope.
   */
  set(name, value) {
    return this.declareVariable(name, convertToRuntimeValues(value));
  }
  /**
   * Declare `name` in this scope with an existing runtime value.
   * @throws {SyntaxError} If `name` is already declared in this scope.
   */
  declareVariable(name, value) {
    if (this.variables.has(name)) {
      throw new SyntaxError(`Variable already declared: ${name}`);
    }
    this.variables.set(name, value);
    return value;
  }
  /**
   * Set variable in the current scope, overwriting any previous value in this scope.
   * See https://jinja.palletsprojects.com/en/3.0.x/templates/#assignments for more information.
   */
  setVariable(name, value) {
    this.variables.set(name, value);
    return value;
  }
  /**
   * Resolve the environment in which the variable is declared.
   * @param {string} name The name of the variable.
   * @returns {Environment} The environment in which the variable is declared.
   * @throws {Error} If no environment in the chain declares `name`.
   */
  resolve(name) {
    if (this.variables.has(name)) {
      return this;
    }
    if (this.parent) {
      return this.parent.resolve(name);
    }
    throw new Error(`Unknown variable: ${name}`);
  }
  /**
   * Look up `name` in this scope and then in its ancestors.
   * Returns an UndefinedValue when the name is not declared anywhere or is stored as null/undefined.
   * Walks the chain directly instead of catching `resolve`'s error, so unrelated errors are not swallowed.
   */
  lookupVariable(name) {
    for (let env = this; env; env = env.parent) {
      if (env.variables.has(name)) {
        return env.variables.get(name) ?? new UndefinedValue();
      }
    }
    return new UndefinedValue();
  }
};
/**
 * Tree-walking interpreter for parsed Jinja templates.
 *
 * Each `evaluate*` method takes an AST node and the `Environment` that holds
 * the variables in scope. It returns a runtime value (StringValue,
 * NumericValue, ArrayValue, ...).
 */
var Interpreter = class {
  global;
  /**
   * @param {Environment} [env] Global environment to run programs in. A new,
   *   empty Environment is created when omitted.
   */
  constructor(env) {
    this.global = env ?? new Environment();
  }
  /**
   * Run the program in the global environment.
   * @param program Root AST node (normally a `Program`).
   * @returns The evaluated result. For a `Program` this is a StringValue with the rendered text.
   */
  run(program) {
    return this.evaluate(program, this.global);
  }
  /**
   * Evaluates expressions following the binary operation type.
   *
   * `and` / `or` short-circuit and return one of their operands rather than a
   * coerced boolean. `==` / `!=` compare the raw `.value`s. Every other
   * operator is dispatched on the operand types.
   * @throws {Error} If an operator other than `and`/`or`/`==`/`!=` gets an undefined or null operand.
   * @throws {SyntaxError} If the operator is not supported for the operand types.
   */
  evaluateBinaryExpression(node, environment) {
    const left = this.evaluate(node.left, environment);
    switch (node.operator.value) {
      case "and":
        return left.__bool__().value ? this.evaluate(node.right, environment) : left;
      case "or":
        return left.__bool__().value ? left : this.evaluate(node.right, environment);
    }
    const right = this.evaluate(node.right, environment);
    switch (node.operator.value) {
      // NOTE(review): this is JavaScript loose equality, so e.g. `1 == "1"` is true
      // here although it is False in Python/Jinja. Kept as-is for compatibility.
      case "==":
        return new BooleanValue(left.value == right.value);
      case "!=":
        return new BooleanValue(left.value != right.value);
    }
    if (left instanceof UndefinedValue || right instanceof UndefinedValue) {
      throw new Error("Cannot perform operation on undefined values");
    } else if (left instanceof NullValue || right instanceof NullValue) {
      throw new Error("Cannot perform operation on null values");
    } else if (left instanceof NumericValue && right instanceof NumericValue) {
      switch (node.operator.value) {
        case "+":
          return new NumericValue(left.value + right.value);
        case "-":
          return new NumericValue(left.value - right.value);
        case "*":
          return new NumericValue(left.value * right.value);
        case "/":
          return new NumericValue(left.value / right.value);
        case "%":
          return new NumericValue(left.value % right.value);
        case "<":
          return new BooleanValue(left.value < right.value);
        case ">":
          return new BooleanValue(left.value > right.value);
        case ">=":
          return new BooleanValue(left.value >= right.value);
        case "<=":
          return new BooleanValue(left.value <= right.value);
      }
    } else if (left instanceof ArrayValue && right instanceof ArrayValue) {
      switch (node.operator.value) {
        case "+":
          return new ArrayValue(left.value.concat(right.value));
      }
    } else if (right instanceof ArrayValue) {
      // Membership compares the raw `.value`s with strict equality.
      const member = right.value.some((x) => x.value === left.value);
      switch (node.operator.value) {
        case "in":
          return new BooleanValue(member);
        case "not in":
          return new BooleanValue(!member);
      }
    }
    if (left instanceof StringValue || right instanceof StringValue) {
      switch (node.operator.value) {
        case "+":
          return new StringValue(left.value.toString() + right.value.toString());
      }
    }
    if (left instanceof StringValue && right instanceof StringValue) {
      switch (node.operator.value) {
        case "in":
          return new BooleanValue(right.value.includes(left.value));
        case "not in":
          return new BooleanValue(!right.value.includes(left.value));
      }
    }
    if (left instanceof StringValue && right instanceof ObjectValue) {
      switch (node.operator.value) {
        case "in":
          return new BooleanValue(right.value.has(left.value));
        case "not in":
          return new BooleanValue(!right.value.has(left.value));
      }
    }
    throw new SyntaxError(`Unknown operator "${node.operator.value}" between ${left.type} and ${right.type}`);
  }
  /**
   * Evaluate call arguments and split them into positional and keyword groups.
   * @returns {[Array, Map]} The evaluated positional arguments, and a map from keyword name to evaluated value.
   * @throws {Error} If a positional argument follows a keyword argument.
   */
  evaluateArguments(args, environment) {
    const positionalArguments = [];
    const keywordArguments = /* @__PURE__ */ new Map();
    for (const argument of args) {
      if (argument.type === "KeywordArgumentExpression") {
        const kwarg = argument;
        keywordArguments.set(kwarg.key.value, this.evaluate(kwarg.value, environment));
      } else {
        if (keywordArguments.size > 0) {
          throw new Error("Positional arguments must come before keyword arguments");
        }
        positionalArguments.push(this.evaluate(argument, environment));
      }
    }
    return [positionalArguments, keywordArguments];
  }
  /**
   * Prefix lines of `text` with `width` spaces, as Jinja's `indent` filter does.
   * @param {string} text Text to indent.
   * @param {number} width Number of spaces to prepend to each indented line.
   * @param {*} indentFirst Truthy to also indent the first line.
   * @param {*} indentBlank Truthy to also indent empty lines.
   * @returns {string} The indented text.
   */
  indentLines(text, width, indentFirst, indentBlank) {
    const padding = " ".repeat(width);
    return text
      .split("\n")
      .map((line, i) => (!indentFirst && i === 0) || (!indentBlank && line.length === 0) ? line : padding + line)
      .join("\n");
  }
  /**
   * Evaluates expressions following the filter operation type.
   *
   * Supports bare filters (`x | upper`) and filters called with arguments
   * (`x | indent(2)`). Filters return new values and never mutate their operand.
   * @throws {Error} If the filter is unknown or cannot be applied to the operand type.
   */
  evaluateFilterExpression(node, environment) {
    const operand = this.evaluate(node.operand, environment);
    if (node.filter.type === "Identifier") {
      const filter = node.filter;
      if (filter.value === "tojson") {
        return new StringValue(toJSON(operand));
      }
      if (operand instanceof ArrayValue) {
        switch (filter.value) {
          case "list":
            return operand;
          case "first":
            // An empty array yields an UndefinedValue, not a raw JS `undefined`.
            return operand.value[0] ?? new UndefinedValue();
          case "last":
            return operand.value[operand.value.length - 1] ?? new UndefinedValue();
          case "length":
            return new NumericValue(operand.value.length);
          case "reverse":
            // Copy first: Array.prototype.reverse works in place and would alter the caller's variable.
            return new ArrayValue(operand.value.slice().reverse());
          case "sort":
            // Copy first: Array.prototype.sort also works in place.
            return new ArrayValue(
              operand.value.slice().sort((a, b) => {
                if (a.type !== b.type) {
                  throw new Error(`Cannot compare different types: ${a.type} and ${b.type}`);
                }
                switch (a.type) {
                  case "NumericValue":
                    return a.value - b.value;
                  case "StringValue":
                    return a.value.localeCompare(b.value);
                  default:
                    throw new Error(`Cannot compare type: ${a.type}`);
                }
              })
            );
          default:
            throw new Error(`Unknown ArrayValue filter: ${filter.value}`);
        }
      } else if (operand instanceof StringValue) {
        switch (filter.value) {
          case "length":
            return new NumericValue(operand.value.length);
          case "upper":
            return new StringValue(operand.value.toUpperCase());
          case "lower":
            return new StringValue(operand.value.toLowerCase());
          case "title":
            return new StringValue(titleCase(operand.value));
          case "capitalize":
            return new StringValue(operand.value.charAt(0).toUpperCase() + operand.value.slice(1));
          case "trim":
            return new StringValue(operand.value.trim());
          case "indent":
            // Same defaults as `indent()`: width 4, first line and empty lines left as-is.
            return new StringValue(this.indentLines(operand.value, 4, false, false));
          case "string":
            return operand;
          default:
            throw new Error(`Unknown StringValue filter: ${filter.value}`);
        }
      } else if (operand instanceof NumericValue) {
        switch (filter.value) {
          case "abs":
            return new NumericValue(Math.abs(operand.value));
          default:
            throw new Error(`Unknown NumericValue filter: ${filter.value}`);
        }
      } else if (operand instanceof ObjectValue) {
        switch (filter.value) {
          case "items":
            return new ArrayValue(
              Array.from(operand.value.entries()).map(([key, value]) => new ArrayValue([new StringValue(key), value]))
            );
          case "length":
            return new NumericValue(operand.value.size);
          default:
            throw new Error(`Unknown ObjectValue filter: ${filter.value}`);
        }
      }
      throw new Error(`Cannot apply filter "${filter.value}" to type: ${operand.type}`);
    } else if (node.filter.type === "CallExpression") {
      const filter = node.filter;
      if (filter.callee.type !== "Identifier") {
        throw new Error(`Unknown filter: ${filter.callee.type}`);
      }
      const filterName = filter.callee.value;
      if (filterName === "tojson") {
        const [, kwargs] = this.evaluateArguments(filter.args, environment);
        const indent = kwargs.get("indent") ?? new NullValue();
        if (!(indent instanceof NumericValue || indent instanceof NullValue)) {
          throw new Error("If set, indent must be a number");
        }
        return new StringValue(toJSON(operand, indent.value));
      }
      if (operand instanceof ArrayValue) {
        switch (filterName) {
          case "selectattr":
          case "rejectattr": {
            const select = filterName === "selectattr";
            if (operand.value.some((x) => !(x instanceof ObjectValue))) {
              throw new Error(`\`${filterName}\` can only be applied to array of objects`);
            }
            if (filter.args.some((x) => x.type !== "StringLiteral")) {
              throw new Error(`arguments of \`${filterName}\` must be strings`);
            }
            const [attr, testName, value] = filter.args.map((x) => this.evaluate(x, environment));
            let testFunction;
            if (testName) {
              const test = environment.tests.get(testName.value);
              if (!test) {
                throw new Error(`Unknown test: ${testName.value}`);
              }
              testFunction = test;
            } else {
              // Without a test name, select on the truthiness of the attribute.
              testFunction = (...x) => x[0].__bool__().value;
            }
            const filtered = operand.value.filter((item) => {
              const a = item.value.get(attr.value);
              const result = a ? testFunction(a, value) : false;
              return select ? result : !result;
            });
            return new ArrayValue(filtered);
          }
          case "map": {
            const [, kwargs] = this.evaluateArguments(filter.args, environment);
            if (kwargs.has("attribute")) {
              const attr = kwargs.get("attribute");
              if (!(attr instanceof StringValue)) {
                throw new Error("attribute must be a string");
              }
              const defaultValue = kwargs.get("default");
              const mapped = operand.value.map((item) => {
                if (!(item instanceof ObjectValue)) {
                  throw new Error("items in map must be an object");
                }
                return item.value.get(attr.value) ?? defaultValue ?? new UndefinedValue();
              });
              return new ArrayValue(mapped);
            } else {
              throw new Error("`map` expressions without `attribute` set are not currently supported.");
            }
          }
        }
        throw new Error(`Unknown ArrayValue filter: ${filterName}`);
      } else if (operand instanceof StringValue) {
        switch (filterName) {
          case "indent": {
            const [args, kwargs] = this.evaluateArguments(filter.args, environment);
            const width = args.at(0) ?? kwargs.get("width") ?? new NumericValue(4);
            if (!(width instanceof NumericValue)) {
              throw new Error("width must be a number");
            }
            const first = args.at(1) ?? kwargs.get("first") ?? new BooleanValue(false);
            const blank = args.at(2) ?? kwargs.get("blank") ?? new BooleanValue(false);
            return new StringValue(this.indentLines(operand.value, width.value, first.value, blank.value));
          }
        }
        throw new Error(`Unknown StringValue filter: ${filterName}`);
      } else {
        throw new Error(`Cannot apply filter "${filterName}" to type: ${operand.type}`);
      }
    }
    throw new Error(`Unknown filter: ${node.filter.type}`);
  }
  /**
   * Evaluates expressions following the test operation type (`x is defined`, `x is not none`, ...).
   * Tests are looked up in `environment.tests` and called with the operand only.
   * @throws {Error} If the test name is unknown.
   */
  evaluateTestExpression(node, environment) {
    const operand = this.evaluate(node.operand, environment);
    const test = environment.tests.get(node.test.value);
    if (!test) {
      throw new Error(`Unknown test: ${node.test.value}`);
    }
    const result = test(operand);
    return new BooleanValue(node.negate ? !result : result);
  }
  /**
   * Evaluates expressions following the unary operation type.
   * `not` negates the operand's `__bool__()` truthiness, the same truthiness used by
   * `if`, `and` and `or`, so e.g. `not []` is true.
   * @throws {SyntaxError} For any operator other than `not`.
   */
  evaluateUnaryExpression(node, environment) {
    const argument = this.evaluate(node.argument, environment);
    switch (node.operator.value) {
      case "not":
        return new BooleanValue(!argument.__bool__().value);
      default:
        throw new SyntaxError(`Unknown operator: ${node.operator.value}`);
    }
  }
  /** Evaluate the top-level statements of a `Program` node. */
  evalProgram(program, environment) {
    return this.evaluateBlock(program.body, environment);
  }
  /**
   * Evaluate statements in order and concatenate their `.value`s.
   * Results that are NullValue or UndefinedValue contribute nothing.
   * @returns {StringValue} The concatenated output.
   */
  evaluateBlock(statements, environment) {
    let result = "";
    for (const statement of statements) {
      const lastEvaluated = this.evaluate(statement, environment);
      if (lastEvaluated.type !== "NullValue" && lastEvaluated.type !== "UndefinedValue") {
        result += lastEvaluated.value;
      }
    }
    return new StringValue(result);
  }
  /** Look up an identifier in the scope chain. Unknown names yield an UndefinedValue. */
  evaluateIdentifier(node, environment) {
    return environment.lookupVariable(node.value);
  }
  /**
   * Call a FunctionValue. Keyword arguments, when present, are appended to the
   * positional arguments as a single KeywordArgumentsValue.
   * @throws {Error} If the callee is not a FunctionValue.
   */
  evaluateCallExpression(expr, environment) {
    const [args, kwargs] = this.evaluateArguments(expr.args, environment);
    if (kwargs.size > 0) {
      args.push(new KeywordArgumentsValue(kwargs));
    }
    const fn = this.evaluate(expr.callee, environment);
    if (fn.type !== "FunctionValue") {
      throw new Error(`Cannot call something that is not a function: got ${fn.type}`);
    }
    return fn.value(args, environment);
  }
  /**
   * Evaluate `object[start:stop:step]` on an array or string.
   * Omitted bounds evaluate to UndefinedValue, so `undefined` is passed to `slice`.
   * @throws {Error} If the object is not an array/string or a bound is not numeric.
   */
  evaluateSliceExpression(object, expr, environment) {
    if (!(object instanceof ArrayValue || object instanceof StringValue)) {
      throw new Error("Slice object must be an array or string");
    }
    const start = this.evaluate(expr.start, environment);
    const stop = this.evaluate(expr.stop, environment);
    const step = this.evaluate(expr.step, environment);
    if (!(start instanceof NumericValue || start instanceof UndefinedValue)) {
      throw new Error("Slice start must be numeric or undefined");
    }
    if (!(stop instanceof NumericValue || stop instanceof UndefinedValue)) {
      throw new Error("Slice stop must be numeric or undefined");
    }
    if (!(step instanceof NumericValue || step instanceof UndefinedValue)) {
      throw new Error("Slice step must be numeric or undefined");
    }
    if (object instanceof ArrayValue) {
      return new ArrayValue(slice(object.value, start.value, stop.value, step.value));
    } else {
      return new StringValue(slice(Array.from(object.value), start.value, stop.value, step.value).join(""));
    }
  }
  /**
   * Evaluate `object.prop`, `object[key]`, `object[index]` or `object[slice]`.
   * Numeric indices use `Array.prototype.at` / `String.prototype.at`, so negative indices count from the end.
   * Missing properties and out-of-range indices yield an UndefinedValue.
   */
  evaluateMemberExpression(expr, environment) {
    const object = this.evaluate(expr.object, environment);
    let property;
    if (expr.computed) {
      if (expr.property.type === "SliceExpression") {
        return this.evaluateSliceExpression(object, expr.property, environment);
      } else {
        property = this.evaluate(expr.property, environment);
      }
    } else {
      property = new StringValue(expr.property.value);
    }
    let value;
    if (object instanceof ObjectValue) {
      if (!(property instanceof StringValue)) {
        throw new Error(`Cannot access property with non-string: got ${property.type}`);
      }
      value = object.value.get(property.value) ?? object.builtins.get(property.value);
    } else if (object instanceof ArrayValue || object instanceof StringValue) {
      if (property instanceof NumericValue) {
        const item = object.value.at(property.value);
        // For strings `.at` yields a raw character, or undefined when out of range.
        // Only real characters are wrapped, so an out-of-range index becomes UndefinedValue below.
        value = object instanceof StringValue && item !== void 0 ? new StringValue(item) : item;
      } else if (property instanceof StringValue) {
        value = object.builtins.get(property.value);
      } else {
        throw new Error(`Cannot access property with non-string/non-number: got ${property.type}`);
      }
    } else {
      if (!(property instanceof StringValue)) {
        throw new Error(`Cannot access property with non-string: got ${property.type}`);
      }
      value = object.builtins.get(property.value);
    }
    return value instanceof RuntimeValue ? value : new UndefinedValue();
  }
  /**
   * Evaluate `{% set %}`: assign to a variable in the current scope, or to a key of an object.
   * @returns {NullValue} Assignments produce no output.
   * @throws {Error} If the assignment target is not a variable or an `object.identifier` member.
   */
  evaluateSet(node, environment) {
    const rhs = this.evaluate(node.value, environment);
    if (node.assignee.type === "Identifier") {
      const variableName = node.assignee.value;
      environment.setVariable(variableName, rhs);
    } else if (node.assignee.type === "MemberExpression") {
      const member = node.assignee;
      const object = this.evaluate(member.object, environment);
      if (!(object instanceof ObjectValue)) {
        throw new Error("Cannot assign to member of non-object");
      }
      if (member.property.type !== "Identifier") {
        throw new Error("Cannot assign to member with non-identifier property");
      }
      object.value.set(member.property.value, rhs);
    } else {
      throw new Error(`Invalid LHS inside assignment expression: ${JSON.stringify(node.assignee)}`);
    }
    return new NullValue();
  }
  /** Evaluate `{% if %}`: render `body` when the test is truthy, otherwise `alternate`. */
  evaluateIf(node, environment) {
    const test = this.evaluate(node.test, environment);
    return this.evaluateBlock(test.__bool__().value ? node.body : node.alternate, environment);
  }
  /**
   * Evaluate a `{% for %}` loop over an ArrayValue.
   *
   * Items that fail an inline `if` filter are dropped before the `loop.*`
   * variables are computed, so `loop.index` and `loop.length` count only the kept items.
   * The `else` block (`node.defaultBlock`) renders when no item was iterated.
   * @throws {Error} If the iterable is not an array or tuple unpacking does not match.
   */
  evaluateFor(node, environment) {
    const scope = new Environment(environment);
    let test, iterable;
    if (node.iterable.type === "SelectExpression") {
      const select = node.iterable;
      iterable = this.evaluate(select.iterable, scope);
      test = select.test;
    } else {
      iterable = this.evaluate(node.iterable, scope);
    }
    if (!(iterable instanceof ArrayValue)) {
      throw new Error(`Expected iterable type in for loop: got ${iterable.type}`);
    }
    const items = [];
    const scopeUpdateFunctions = [];
    for (let i = 0; i < iterable.value.length; ++i) {
      const loopScope = new Environment(scope);
      const current = iterable.value[i];
      let scopeUpdateFunction;
      if (node.loopvar.type === "Identifier") {
        scopeUpdateFunction = (scope2) => scope2.setVariable(node.loopvar.value, current);
      } else if (node.loopvar.type === "TupleLiteral") {
        const loopvar = node.loopvar;
        if (current.type !== "ArrayValue") {
          throw new Error(`Cannot unpack non-iterable type: ${current.type}`);
        }
        const c = current;
        if (loopvar.value.length !== c.value.length) {
          throw new Error(`Too ${loopvar.value.length > c.value.length ? "few" : "many"} items to unpack`);
        }
        scopeUpdateFunction = (scope2) => {
          for (let j = 0; j < loopvar.value.length; ++j) {
            if (loopvar.value[j].type !== "Identifier") {
              throw new Error(`Cannot unpack non-identifier type: ${loopvar.value[j].type}`);
            }
            scope2.setVariable(loopvar.value[j].value, c.value[j]);
          }
        };
      } else {
        throw new Error(`Invalid loop variable(s): ${node.loopvar.type}`);
      }
      if (test) {
        // The inline filter sees the loop variables in a throwaway child scope.
        scopeUpdateFunction(loopScope);
        const testValue = this.evaluate(test, loopScope);
        if (!testValue.__bool__().value) {
          continue;
        }
      }
      items.push(current);
      scopeUpdateFunctions.push(scopeUpdateFunction);
    }
    let result = "";
    let noIteration = true;
    for (let i = 0; i < items.length; ++i) {
      const loop = /* @__PURE__ */ new Map([
        ["index", new NumericValue(i + 1)],
        ["index0", new NumericValue(i)],
        ["revindex", new NumericValue(items.length - i)],
        ["revindex0", new NumericValue(items.length - i - 1)],
        ["first", new BooleanValue(i === 0)],
        ["last", new BooleanValue(i === items.length - 1)],
        ["length", new NumericValue(items.length)],
        ["previtem", i > 0 ? items[i - 1] : new UndefinedValue()],
        ["nextitem", i < items.length - 1 ? items[i + 1] : new UndefinedValue()]
      ]);
      scope.setVariable("loop", new ObjectValue(loop));
      scopeUpdateFunctions[i](scope);
      const evaluated = this.evaluateBlock(node.body, scope);
      result += evaluated.value;
      noIteration = false;
    }
    if (noIteration) {
      const defaultEvaluated = this.evaluateBlock(node.defaultBlock, scope);
      result += defaultEvaluated.value;
    }
    return new StringValue(result);
  }
  /**
   * Define a macro by binding `node.name` to a FunctionValue in `environment`.
   *
   * Each call binds the arguments in a new child of the caller's scope.
   * A parameter without a default may be passed positionally or by keyword.
   * A parameter with a default takes, in order: the positional argument, the
   * keyword argument, or its default evaluated in the macro scope.
   * See https://jinja.palletsprojects.com/en/3.1.x/templates/#macros for more information.
   * @returns {NullValue} Macro definitions produce no output.
   */
  evaluateMacro(node, environment) {
    environment.setVariable(
      node.name.value,
      new FunctionValue((args, scope) => {
        const macroScope = new Environment(scope);
        args = args.slice();
        let kwargs;
        if (args.at(-1)?.type === "KeywordArgumentsValue") {
          kwargs = args.pop();
        }
        for (let i = 0; i < node.args.length; ++i) {
          const nodeArg = node.args[i];
          const passedArg = args[i];
          if (nodeArg.type === "Identifier") {
            const identifier = nodeArg;
            const value = passedArg ?? kwargs?.value.get(identifier.value);
            if (!value) {
              throw new Error(`Missing positional argument: ${identifier.value}`);
            }
            macroScope.setVariable(identifier.value, value);
          } else if (nodeArg.type === "KeywordArgumentExpression") {
            const kwarg = nodeArg;
            const value = passedArg ?? // Try positional arguments first
            kwargs?.value.get(kwarg.key.value) ?? // Look in user-passed kwargs
            this.evaluate(kwarg.value, macroScope);
            macroScope.setVariable(kwarg.key.value, value);
          } else {
            throw new Error(`Unknown argument type: ${nodeArg.type}`);
          }
        }
        return this.evaluateBlock(node.body, macroScope);
      })
    );
    return new NullValue();
  }
  /**
   * Dispatch an AST node to the matching evaluator.
   * A missing (`undefined`) node evaluates to an UndefinedValue.
   * @throws {SyntaxError} For unknown node types.
   */
  evaluate(statement, environment) {
    if (statement === void 0)
      return new UndefinedValue();
    switch (statement.type) {
      case "Program":
        return this.evalProgram(statement, environment);
      case "Set":
        return this.evaluateSet(statement, environment);
      case "If":
        return this.evaluateIf(statement, environment);
      case "For":
        return this.evaluateFor(statement, environment);
      case "Macro":
        return this.evaluateMacro(statement, environment);
      case "NumericLiteral":
        return new NumericValue(Number(statement.value));
      case "StringLiteral":
        return new StringValue(statement.value);
      case "BooleanLiteral":
        return new BooleanValue(statement.value);
      case "NullLiteral":
        return new NullValue(statement.value);
      case "ArrayLiteral":
        return new ArrayValue(statement.value.map((x) => this.evaluate(x, environment)));
      case "TupleLiteral":
        return new TupleValue(statement.value.map((x) => this.evaluate(x, environment)));
      case "ObjectLiteral": {
        const mapping = /* @__PURE__ */ new Map();
        for (const [key, value] of statement.value) {
          const evaluatedKey = this.evaluate(key, environment);
          if (!(evaluatedKey instanceof StringValue)) {
            throw new Error(`Object keys must be strings: got ${evaluatedKey.type}`);
          }
          mapping.set(evaluatedKey.value, this.evaluate(value, environment));
        }
        return new ObjectValue(mapping);
      }
      case "Identifier":
        return this.evaluateIdentifier(statement, environment);
      case "CallExpression":
        return this.evaluateCallExpression(statement, environment);
      case "MemberExpression":
        return this.evaluateMemberExpression(statement, environment);
      case "UnaryExpression":
        return this.evaluateUnaryExpression(statement, environment);
      case "BinaryExpression":
        return this.evaluateBinaryExpression(statement, environment);
      case "FilterExpression":
        return this.evaluateFilterExpression(statement, environment);
      case "TestExpression":
        return this.evaluateTestExpression(statement, environment);
      default:
        throw new SyntaxError(`Unknown node type: ${statement.type}`);
    }
  }
};
/**
 * Convert a plain JavaScript value into the interpreter's runtime-value form.
 *
 * Arrays and objects are converted recursively. Objects become an ObjectValue
 * over their own enumerable string-keyed entries. Functions are wrapped so that,
 * when a template calls them, they receive each argument's unwrapped `.value`.
 * Their return value is converted back, with `undefined` mapped to `null`.
 * @param {*} input Value to convert.
 * @returns The corresponding runtime value.
 * @throws {Error} For values with no runtime equivalent (e.g. a bigint).
 */
function convertToRuntimeValues(input) {
  const kind = typeof input;
  if (kind === "number") {
    return new NumericValue(input);
  }
  if (kind === "string") {
    return new StringValue(input);
  }
  if (kind === "boolean") {
    return new BooleanValue(input);
  }
  if (kind === "undefined") {
    return new UndefinedValue();
  }
  if (kind === "object") {
    if (input === null) {
      return new NullValue();
    }
    if (Array.isArray(input)) {
      return new ArrayValue(input.map((item) => convertToRuntimeValues(item)));
    }
    const convertedEntries = Object.entries(input).map(([key, value]) => [key, convertToRuntimeValues(value)]);
    return new ObjectValue(new Map(convertedEntries));
  }
  if (kind === "function") {
    return new FunctionValue((args, _scope) => {
      const unwrapped = args.map((arg) => arg.value);
      return convertToRuntimeValues(input(...unwrapped) ?? null);
    });
  }
  throw new Error(`Cannot convert to runtime value: ${input}`);
}
/**
 * Serialize a runtime value to a JSON string (used by the `tojson` filter).
 *
 * Without `indent`, items are separated by ", " and keys by ": ". With a
 * truthy `indent`, each element/entry goes on its own line, indented by
 * `indent` spaces per nesting level. Empty containers always render as `[]` / `{}`.
 * Null and undefined values serialize as `null`.
 * @param input Runtime value to serialize.
 * @param {number|null|undefined} [indent] Spaces per nesting level; a falsy value means compact output.
 * @param {number} [depth] Current nesting level (internal, used for recursion).
 * @returns {string} The JSON text.
 * @throws {Error} If a (nested) value's type has no JSON representation.
 */
function toJSON(input, indent, depth) {
  const currentDepth = depth ?? 0;
  switch (input.type) {
    case "NullValue":
    case "UndefinedValue":
      return "null";
    case "NumericValue":
    case "StringValue":
    case "BooleanValue":
      return JSON.stringify(input.value);
    case "ArrayValue":
    case "ObjectValue": {
      const indentValue = indent ? " ".repeat(indent) : "";
      const basePadding = "\n" + indentValue.repeat(currentDepth);
      const childrenPadding = basePadding + indentValue;
      if (input.type === "ArrayValue") {
        const core = input.value.map((x) => toJSON(x, indent, currentDepth + 1));
        if (core.length === 0) {
          return "[]";
        }
        return indent ? `[${childrenPadding}${core.join(`,${childrenPadding}`)}${basePadding}]` : `[${core.join(", ")}]`;
      }
      const core = Array.from(input.value.entries()).map(([key, value]) => {
        // Keys go through JSON.stringify so quotes, backslashes and control characters are escaped.
        const v = `${JSON.stringify(String(key))}: ${toJSON(value, indent, currentDepth + 1)}`;
        return indent ? `${childrenPadding}${v}` : v;
      });
      if (core.length === 0) {
        return "{}";
      }
      return indent ? `{${core.join(",")}${basePadding}}` : `{${core.join(", ")}}`;
    }
    default:
      throw new Error(`Cannot convert to JSON: ${input.type}`);
  }
}
// src/index.ts
/**
 * A compiled Jinja template.
 * The source is tokenized and parsed once in the constructor, with
 * `lstrip_blocks` and `trim_blocks` enabled. It can then be rendered any
 * number of times with different variables.
 */
var Template = class {
  parsed;
  /**
   * @param {string} template The template string
   */
  constructor(template) {
    const tokens = tokenize(template, {
      lstrip_blocks: true,
      trim_blocks: true
    });
    this.parsed = parse(tokens);
  }
  /**
   * Render the template.
   *
   * The globals `false`, `true`, `raise_exception` and `range` are always
   * defined. `raise_exception(msg)` throws an Error with that message.
   * @param {Object} [items] Variables to expose to the template. They are
   *   converted to runtime values; when omitted (or null), no user variables are set.
   * @returns {string} The rendered output.
   * @throws {SyntaxError} If `items` contains a key that is already a global
   *   (Environment.set refuses to redeclare a variable).
   */
  render(items) {
    const env = new Environment();
    env.set("false", false);
    env.set("true", true);
    env.set("raise_exception", (args) => {
      throw new Error(args);
    });
    env.set("range", range);
    // Object.entries(undefined/null) would throw, so a missing `items` means "no variables".
    for (const [key, value] of Object.entries(items ?? {})) {
      env.set(key, value);
    }
    const interpreter = new Interpreter(env);
    const result = interpreter.run(this.parsed);
    return result.value;
  }
};
/***/ }),
/***/ "./node_modules/onnxruntime-common/dist/esm/backend-impl.js":
/*!******************************************************************!*\
!*** ./node_modules/onnxruntime-common/dist/esm/backend-impl.js ***!
\******************************************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ registerBackend: () => (/* binding */ registerBackend),
/* harmony export */ resolveBackendAndExecutionProviders: () => (/* binding */ resolveBackendAndExecutionProviders)
/* harmony export */ });
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// Registered backends keyed by name. Each entry is created as `{ backend, priority }`;
// tryResolveAndInitializeBackend later stores its initialization state on the same
// entry (`initPromise`, `initialized`, `aborted`, `error`).
const backends = new Map();
// Names of registered backends whose priority is >= 0, kept in descending priority
// order. Used as the default candidate list when session options request no
// specific execution providers.
const backendsSortedByPriority = [];
/**
 * Register a backend.
 *
 * A name may be registered more than once:
 * - A registration with a lower priority than the current one is ignored.
 * - An equal priority is only allowed for the same backend object.
 * - A higher priority replaces the current registration. If it is the same
 *   backend object, only its priority is raised and its initialization state is kept.
 *
 * @param name - the name as a key to lookup as an execution provider.
 * @param backend - the backend object.
 * @param priority - an integer indicating the priority of the backend. Higher number means higher priority. if priority
 * < 0, it will be considered as a 'beta' version and will not be used as a fallback backend by default.
 * @throws TypeError if `backend` does not provide `init` and `createInferenceSessionHandler` functions.
 * @throws Error if a different backend is already registered under `name` with the same priority.
 *
 * @ignore
 */
const registerBackend = (name, backend, priority) => {
  if (backend && typeof backend.init === 'function' && typeof backend.createInferenceSessionHandler === 'function') {
    const currentBackend = backends.get(name);
    if (currentBackend === undefined) {
      backends.set(name, { backend, priority });
    }
    else if (currentBackend.priority > priority) {
      // same name is already registered with a higher priority. skip registration.
      return;
    }
    else if (currentBackend.priority === priority) {
      if (currentBackend.backend !== backend) {
        throw new Error(`cannot register backend "${name}" using priority ${priority}`);
      }
    }
    else if (currentBackend.backend === backend) {
      // Same backend re-registered with a higher priority: raise the priority but keep its initialization state.
      currentBackend.priority = priority;
    }
    else {
      // A different backend with a higher priority replaces the previous registration.
      backends.set(name, { backend, priority });
    }
    if (priority >= 0) {
      const existingIndex = backendsSortedByPriority.indexOf(name);
      if (existingIndex !== -1) {
        backendsSortedByPriority.splice(existingIndex, 1);
      }
      // Insert ahead of the first backend whose priority is not higher, keeping descending order.
      const insertAt = backendsSortedByPriority.findIndex((other) => backends.get(other).priority <= priority);
      if (insertAt === -1) {
        backendsSortedByPriority.push(name);
      }
      else {
        backendsSortedByPriority.splice(insertAt, 0, name);
      }
    }
    return;
  }
  throw new TypeError('not a valid backend');
};
/**
* Try to resolve and initialize a backend.
*
* The backend's `init()` is called at most once per registry entry. While it is
* pending, the promise is kept on the entry (`initPromise`) so concurrent callers
* await the same initialization. The outcome is cached on the entry (`initialized`,
* or `aborted` + `error`) and returned directly on later calls.
*
* @param backendName - the name of the backend.
* @returns the backend instance if resolved and initialized successfully, or an error message if failed.
*/
const tryResolveAndInitializeBackend = async (backendName) => {
const backendInfo = backends.get(backendName);
if (!backendInfo) {
return 'backend not found.';
}
// Cached outcome of an earlier initialization attempt.
if (backendInfo.initialized) {
return backendInfo.backend;
}
else if (backendInfo.aborted) {
return backendInfo.error;
}
else {
// If another caller already started init(), wait on its promise instead of calling init() again.
const isInitializing = !!backendInfo.initPromise;
try {
if (!isInitializing) {
backendInfo.initPromise = backendInfo.backend.init(backendName);
}
await backendInfo.initPromise;
backendInfo.initialized = true;
return backendInfo.backend;
}
catch (e) {
// Only the caller that started init() records the failure.
// NOTE(review): a waiting caller returns `backendInfo.error` as recorded by the initiating
// caller. This relies on the initiating caller's `await` resuming first; if it has not,
// `undefined` is returned instead of an error string.
if (!isInitializing) {
backendInfo.error = `${e}`;
backendInfo.aborted = true;
}
return backendInfo.error;
}
finally {
// The shared promise is no longer needed once settled; the outcome now lives in `initialized` / `aborted`.
delete backendInfo.initPromise;
}
}
};
/**
 * Resolve execution providers from the specific session options.
 *
 * The candidates are the execution providers named in `options.executionProviders`
 * or, when none are given, every backend in `backendsSortedByPriority`. Each
 * candidate is resolved and initialized one after another. The first one that
 * succeeds is selected, and every candidate name resolving to that same backend
 * object is kept in the returned EP list. A warning is logged for each explicitly
 * requested provider that failed.
 *
 * @param options - the session options object.
 * @returns a promise that resolves to a tuple of an initialized backend instance and a session options object with
 * filtered EP list. The options object is a Proxy over `options` that overrides only `executionProviders`.
 * @throws (through the returned promise) if no candidate backend could be initialized.
 *
 * @ignore
 */
const resolveBackendAndExecutionProviders = async (options) => {
  const epName = (ep) => (typeof ep === 'string' ? ep : ep.name);
  // backend hints come from the session options; fall back to all prioritized backends
  const eps = options.executionProviders || [];
  const backendHints = eps.map(epName);
  const candidates = backendHints.length > 0 ? backendHints : backendsSortedByPriority;
  let selected;
  const failures = [];
  const matchingNames = new Set();
  for (const candidate of candidates) {
    const outcome = await tryResolveAndInitializeBackend(candidate);
    if (typeof outcome === 'string') {
      failures.push({ name: candidate, err: outcome });
      continue;
    }
    if (!selected) {
      selected = outcome;
    }
    if (outcome === selected) {
      matchingNames.add(candidate);
    }
  }
  if (!selected) {
    const details = failures.map(({ name, err }) => `[${name}] ${err}`).join(', ');
    throw new Error(`no available backend found. ERR: ${details}`);
  }
  for (const { name, err } of failures.filter((failure) => backendHints.includes(failure.name))) {
    // eslint-disable-next-line no-console
    console.warn(`removing requested execution provider "${name}" from session options because it is not available: ${err}`);
  }
  const filteredEps = eps.filter((ep) => matchingNames.has(epName(ep)));
  const optionsView = new Proxy(options, {
    get: (target, prop) => (prop === 'executionProviders' ? filteredEps : Reflect.get(target, prop)),
  });
  return [selected, optionsView];
};
//# sourceMappingURL=backend-impl.js.map
/***/ }),
/***/ "./node_modules/onnxruntime-common/dist/esm/backend.js":
/*!*************************************************************!*\
!*** ./node_modules/onnxruntime-common/dist/esm/backend.js ***!
\*************************************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ registerBackend: () => (/* reexport safe */ _backend_impl_js__WEBPACK_IMPORTED_MODULE_0__.registerBackend)
/* harmony export */ });
/* harmony import */ var _backend_impl_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ./backend-impl.js */ "./node_modules/onnxruntime-common/dist/esm/backend-impl.js");
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//# sourceMappingURL=backend.js.map
/***/ }),
/***/ "./node_modules/onnxruntime-common/dist/esm/env-impl.js":
/*!**************************************************************!*\
!*** ./node_modules/onnxruntime-common/dist/esm/env-impl.js ***!
\**************************************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ env: () => (/* binding */ env)
/* harmony export */ });
/* harmony import */ var _version_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ./version.js */ "./node_modules/onnxruntime-common/dist/esm/version.js");
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// Backing storage for the `env.logLevel` accessor.
let logLevelValue = 'warning';
// Values accepted by the `env.logLevel` setter.
const supportedLogLevels = ['verbose', 'info', 'warning', 'error', 'fatal'];
/**
 * Global ONNX Runtime environment flags. Includes per-backend option bags,
 * library versions, and a validated `logLevel` accessor.
 */
const env = {
  wasm: {},
  webgl: {},
  webgpu: {},
  versions: { common: _version_js__WEBPACK_IMPORTED_MODULE_0__.version },
  // Assigning `undefined` is a no-op; anything other than a supported level string throws.
  set logLevel(value) {
    if (value === undefined) {
      return;
    }
    const isSupported = typeof value === 'string' && supportedLogLevels.includes(value);
    if (!isSupported) {
      throw new Error(`Unsupported logging level: ${value}`);
    }
    logLevelValue = value;
  },
  get logLevel() {
    return logLevelValue;
  },
};
// set property 'logLevel' so that they can be correctly transferred to worker by `postMessage()`.
Object.defineProperty(env, 'logLevel', { enumerable: true });
//# sourceMappingURL=env-impl.js.map
/***/ }),
/***/ "./node_modules/onnxruntime-common/dist/esm/env.js":
/*!*********************************************************!*\
!*** ./node_modules/onnxruntime-common/dist/esm/env.js ***!
\*********************************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ env: () => (/* binding */ env)
/* harmony export */ });
/* harmony import */ var _env_impl_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ./env-impl.js */ "./node_modules/onnxruntime-common/dist/esm/env-impl.js");
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
/**
 * Represent a set of flags as a global singleton.
 *
 * This is the same object exported as `env` by env-impl.js. It is re-exported by
 * reference, not copied, so changes made through either module are visible in both.
 */
const env = _env_impl_js__WEBPACK_IMPORTED_MODULE_0__.env;
//# sourceMappingURL=env.js.map
/***/ }),
/***/ "./node_modules/onnxruntime-common/dist/esm/index.js":
/*!***********************************************************!*\
!*** ./node_modules/onnxruntime-common/dist/esm/index.js ***!
\***********************************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ InferenceSession: () => (/* reexport safe */ _inference_session_js__WEBPACK_IMPORTED_MODULE_2__.InferenceSession),
/* harmony export */ TRACE: () => (/* reexport safe */ _trace_js__WEBPACK_IMPORTED_MODULE_6__.TRACE),
/* harmony export */ TRACE_FUNC_BEGIN: () => (/* reexport safe */ _trace_js__WEBPACK_IMPORTED_MODULE_6__.TRACE_FUNC_BEGIN),
/* harmony export */ TRACE_FUNC_END: () => (/* reexport safe */ _trace_js__WEBPACK_IMPORTED_MODULE_6__.TRACE_FUNC_END),
/* harmony export */ Tensor: () => (/* reexport safe */ _tensor_js__WEBPACK_IMPORTED_MODULE_3__.Tensor),
/* harmony export */ TrainingSession: () => (/* reexport safe */ _training_session_js__WEBPACK_IMPORTED_MODULE_9__.TrainingSession),
/* harmony export */ env: () => (/* reexport safe */ _env_js__WEBPACK_IMPORTED_MODULE_1__.env),
/* harmony export */ registerBackend: () => (/* reexport safe */ _backend_js__WEBPACK_IMPORTED_MODULE_0__.registerBackend)
/* harmony export */ });
/* harmony import */ var _backend_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ./backend.js */ "./node_modules/onnxruntime-common/dist/esm/backend.js");
/* harmony import */ var _env_js__WEBPACK_IMPORTED_MODULE_1__ = __webpack_require__(/*! ./env.js */ "./node_modules/onnxruntime-common/dist/esm/env.js");
/* harmony import */ var _inference_session_js__WEBPACK_IMPORTED_MODULE_2__ = __webpack_require__(/*! ./inference-session.js */ "./node_modules/onnxruntime-common/dist/esm/inference-session.js");
/* harmony import */ var _tensor_js__WEBPACK_IMPORTED_MODULE_3__ = __webpack_require__(/*! ./tensor.js */ "./node_modules/onnxruntime-common/dist/esm/tensor.js");
/* harmony import */ var _tensor_conversion_js__WEBPACK_IMPORTED_MODULE_4__ = __webpack_require__(/*! ./tensor-conversion.js */ "./node_modules/onnxruntime-common/dist/esm/tensor-conversion.js");
/* harmony import */ var _tensor_factory_js__WEBPACK_IMPORTED_MODULE_5__ = __webpack_require__(/*! ./tensor-factory.js */ "./node_modules/onnxruntime-common/dist/esm/tensor-factory.js");
/* harmony import */ var _trace_js__WEBPACK_IMPORTED_MODULE_6__ = __webpack_require__(/*! ./trace.js */ "./node_modules/onnxruntime-common/dist/esm/trace.js");
/* harmony import */ var _onnx_model_js__WEBPACK_IMPORTED_MODULE_7__ = __webpack_require__(/*! ./onnx-model.js */ "./node_modules/onnxruntime-common/dist/esm/onnx-model.js");
/* harmony import */ var _onnx_value_js__WEBPACK_IMPORTED_MODULE_8__ = __webpack_require__(/*! ./onnx-value.js */ "./node_modules/onnxruntime-common/dist/esm/onnx-value.js");
/* harmony import */ var _training_session_js__WEBPACK_IMPORTED_MODULE_9__ = __webpack_require__(/*! ./training-session.js */ "./node_modules/onnxruntime-common/dist/esm/training-session.js");
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
/**
* # ONNX Runtime JavaScript API
*
* ONNX Runtime JavaScript API is a unified API for all JavaScript usages, including the following NPM packages:
*
* - [onnxruntime-node](https://www.npmjs.com/package/onnxruntime-node)
* - [onnxruntime-web](https://www.npmjs.com/package/onnxruntime-web)
* - [onnxruntime-react-native](https://www.npmjs.com/package/onnxruntime-react-native)
*
* See also:
* - [Get Started](https://onnxruntime.ai/docs/get-started/with-javascript/)
* - [Inference examples](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/js)
*
* @packageDocumentation
*/
//# sourceMappingURL=index.js.map
/***/ }),
/***/ "./node_modules/onnxruntime-common/dist/esm/inference-session-impl.js":
/*!****************************************************************************!*\
!*** ./node_modules/onnxruntime-common/dist/esm/inference-session-impl.js ***!
\****************************************************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ InferenceSession: () => (/* binding */ InferenceSession)
/* harmony export */ });
/* harmony import */ var _backend_impl_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ./backend-impl.js */ "./node_modules/onnxruntime-common/dist/esm/backend-impl.js");
/* harmony import */ var _tensor_js__WEBPACK_IMPORTED_MODULE_1__ = __webpack_require__(/*! ./tensor.js */ "./node_modules/onnxruntime-common/dist/esm/tensor.js");
/* harmony import */ var _trace_js__WEBPACK_IMPORTED_MODULE_2__ = __webpack_require__(/*! ./trace.js */ "./node_modules/onnxruntime-common/dist/esm/trace.js");
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
/**
 * Backend-agnostic inference session.
 *
 * An instance wraps a backend session handler (created through `InferenceSession.create`) and validates the
 * user-facing arguments of `run()` before delegating to `handler.run()`.
 *
 * Both `run()` and `create()` emit a TRACE_FUNC_BEGIN / TRACE_FUNC_END pair. END is emitted from a `finally`
 * block so the pair stays balanced even when argument validation or the backend throws.
 */
class InferenceSession {
  /**
   * @param handler - backend session handler; the methods below use its `inputNames`, `outputNames`, `run()`,
   *     `dispose()`, `startProfiling()` and `endProfiling()`.
   */
  constructor(handler) {
    this.handler = handler;
  }
  /**
   * Run inference.
   *
   * Supported call shapes:
   *  - run(feeds, options?)
   *  - run(feeds, fetchNames: string[], options?)
   *  - run(feeds, fetches: { [outputName]: Tensor | null }, options?)
   * An object `arg1` is treated as fetches only if it has at least one own key that is an output name whose value
   * is `null` or a Tensor; otherwise it is treated as the run options.
   *
   * @param feeds - object mapping every model input name to a value.
   * @param arg1 - optional fetches (output-name array or object) or run options.
   * @param arg2 - optional run options, used when `arg1` is fetches.
   * @returns an object mapping each output returned by the handler to a Tensor (plain `{type, data, dims}`
   *     results are wrapped in a new Tensor).
   * @throws TypeError for malformed arguments, RangeError for unknown output names, Error for missing inputs.
   */
  async run(feeds, arg1, arg2) {
    (0, _trace_js__WEBPACK_IMPORTED_MODULE_2__.TRACE_FUNC_BEGIN)();
    try {
      const fetches = {};
      let options = {};
      // check inputs
      if (typeof feeds !== 'object' || feeds === null || feeds instanceof _tensor_js__WEBPACK_IMPORTED_MODULE_1__.Tensor || Array.isArray(feeds)) {
        throw new TypeError("'feeds' must be an object that use input names as keys and OnnxValue as corresponding values.");
      }
      let isFetchesEmpty = true;
      // determine which override is being used
      if (typeof arg1 === 'object') {
        if (arg1 === null) {
          throw new TypeError('Unexpected argument[1]: cannot be null.');
        }
        if (arg1 instanceof _tensor_js__WEBPACK_IMPORTED_MODULE_1__.Tensor) {
          throw new TypeError("'fetches' cannot be a Tensor");
        }
        if (Array.isArray(arg1)) {
          if (arg1.length === 0) {
            throw new TypeError("'fetches' cannot be an empty array.");
          }
          isFetchesEmpty = false;
          // output names
          for (const name of arg1) {
            if (typeof name !== 'string') {
              throw new TypeError("'fetches' must be a string array or an object.");
            }
            if (!this.outputNames.includes(name)) {
              throw new RangeError(`'fetches' contains invalid output name: ${name}.`);
            }
            fetches[name] = null;
          }
          if (typeof arg2 === 'object' && arg2 !== null) {
            options = arg2;
          } else if (typeof arg2 !== 'undefined') {
            throw new TypeError("'options' must be an object.");
          }
        } else {
          // decide whether arg1 is fetches or options:
          // if any output name is present and its value is a valid OnnxValue (null or Tensor), it is fetches
          let isFetches = false;
          const arg1Keys = Object.getOwnPropertyNames(arg1);
          for (const name of this.outputNames) {
            if (arg1Keys.includes(name)) {
              const v = arg1[name];
              if (v === null || v instanceof _tensor_js__WEBPACK_IMPORTED_MODULE_1__.Tensor) {
                isFetches = true;
                isFetchesEmpty = false;
                fetches[name] = v;
              }
            }
          }
          if (isFetches) {
            if (typeof arg2 === 'object' && arg2 !== null) {
              options = arg2;
            } else if (typeof arg2 !== 'undefined') {
              throw new TypeError("'options' must be an object.");
            }
          } else {
            options = arg1;
          }
        }
      } else if (typeof arg1 !== 'undefined') {
        throw new TypeError("Unexpected argument[1]: must be 'fetches' or 'options'.");
      }
      // check if all inputs are in feed
      for (const name of this.inputNames) {
        if (typeof feeds[name] === 'undefined') {
          throw new Error(`input '${name}' is missing in 'feeds'.`);
        }
      }
      // if no fetches is specified, we use the full output names list
      if (isFetchesEmpty) {
        for (const name of this.outputNames) {
          fetches[name] = null;
        }
      }
      // feeds, fetches and options are prepared
      const results = await this.handler.run(feeds, fetches, options);
      const returnValue = {};
      for (const [key, result] of Object.entries(results)) {
        returnValue[key] = result instanceof _tensor_js__WEBPACK_IMPORTED_MODULE_1__.Tensor
          ? result
          : new _tensor_js__WEBPACK_IMPORTED_MODULE_1__.Tensor(result.type, result.data, result.dims);
      }
      return returnValue;
    } finally {
      (0, _trace_js__WEBPACK_IMPORTED_MODULE_2__.TRACE_FUNC_END)();
    }
  }
  /** Release the session by disposing the backend handler. */
  async release() {
    return this.handler.dispose();
  }
  /**
   * Create a session from a model path, a Uint8Array, or an (optionally sliced) ArrayBuffer/SharedArrayBuffer.
   *
   * Call shapes:
   *  - create(path: string, options?)
   *  - create(bytes: Uint8Array, options?)
   *  - create(buffer, options?)
   *  - create(buffer, byteOffset, byteLength?, options?)
   *
   * @throws TypeError for unexpected argument types, RangeError for out-of-range byteOffset/byteLength.
   */
  static async create(arg0, arg1, arg2, arg3) {
    (0, _trace_js__WEBPACK_IMPORTED_MODULE_2__.TRACE_FUNC_BEGIN)();
    try {
      // either load from a file or buffer
      let filePathOrUint8Array;
      let options = {};
      if (typeof arg0 === 'string') {
        filePathOrUint8Array = arg0;
        if (typeof arg1 === 'object' && arg1 !== null) {
          options = arg1;
        } else if (typeof arg1 !== 'undefined') {
          throw new TypeError("'options' must be an object.");
        }
      } else if (arg0 instanceof Uint8Array) {
        filePathOrUint8Array = arg0;
        if (typeof arg1 === 'object' && arg1 !== null) {
          options = arg1;
        } else if (typeof arg1 !== 'undefined') {
          throw new TypeError("'options' must be an object.");
        }
      } else if (arg0 instanceof ArrayBuffer ||
        (typeof SharedArrayBuffer !== 'undefined' && arg0 instanceof SharedArrayBuffer)) {
        const buffer = arg0;
        let byteOffset = 0;
        let byteLength = arg0.byteLength;
        if (typeof arg1 === 'object' && arg1 !== null) {
          options = arg1;
        } else if (typeof arg1 === 'number') {
          byteOffset = arg1;
          if (!Number.isSafeInteger(byteOffset)) {
            throw new RangeError("'byteOffset' must be an integer.");
          }
          if (byteOffset < 0 || byteOffset >= buffer.byteLength) {
            throw new RangeError(`'byteOffset' is out of range [0, ${buffer.byteLength}).`);
          }
          byteLength = arg0.byteLength - byteOffset;
          if (typeof arg2 === 'number') {
            byteLength = arg2;
            if (!Number.isSafeInteger(byteLength)) {
              throw new RangeError("'byteLength' must be an integer.");
            }
            if (byteLength <= 0 || byteOffset + byteLength > buffer.byteLength) {
              throw new RangeError(`'byteLength' is out of range (0, ${buffer.byteLength - byteOffset}].`);
            }
            if (typeof arg3 === 'object' && arg3 !== null) {
              options = arg3;
            } else if (typeof arg3 !== 'undefined') {
              throw new TypeError("'options' must be an object.");
            }
          } else if (typeof arg2 !== 'undefined') {
            throw new TypeError("'byteLength' must be a number.");
          }
        } else if (typeof arg1 !== 'undefined') {
          throw new TypeError("'options' must be an object.");
        }
        filePathOrUint8Array = new Uint8Array(buffer, byteOffset, byteLength);
      } else {
        throw new TypeError("Unexpected argument[0]: must be 'path' or 'buffer'.");
      }
      // resolve backend, update session options with validated EPs, and create session handler
      const [backend, optionsWithValidatedEPs] = await (0, _backend_impl_js__WEBPACK_IMPORTED_MODULE_0__.resolveBackendAndExecutionProviders)(options);
      const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, optionsWithValidatedEPs);
      return new InferenceSession(handler);
    } finally {
      (0, _trace_js__WEBPACK_IMPORTED_MODULE_2__.TRACE_FUNC_END)();
    }
  }
  /** Forward to the handler's `startProfiling()`. */
  startProfiling() {
    this.handler.startProfiling();
  }
  /** Forward to the handler's `endProfiling()`. */
  endProfiling() {
    this.handler.endProfiling();
  }
  /** Model input names, as reported by the handler. */
  get inputNames() {
    return this.handler.inputNames;
  }
  /** Model output names, as reported by the handler. */
  get outputNames() {
    return this.handler.outputNames;
  }
}
//# sourceMappingURL=inference-session-impl.js.map
/***/ }),
/***/ "./node_modules/onnxruntime-common/dist/esm/inference-session.js":
/*!***********************************************************************!*\
!*** ./node_modules/onnxruntime-common/dist/esm/inference-session.js ***!
\***********************************************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ InferenceSession: () => (/* binding */ InferenceSession)
/* harmony export */ });
/* harmony import */ var _inference_session_impl_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ./inference-session-impl.js */ "./node_modules/onnxruntime-common/dist/esm/inference-session-impl.js");
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// Public re-export of the implementation class; the binding keeps the PascalCase class name.
// eslint-disable-next-line @typescript-eslint/naming-convention
const { InferenceSession } = _inference_session_impl_js__WEBPACK_IMPORTED_MODULE_0__;
//# sourceMappingURL=inference-session.js.map
/***/ }),
/***/ "./node_modules/onnxruntime-common/dist/esm/onnx-model.js":
/*!****************************************************************!*\
!*** ./node_modules/onnxruntime-common/dist/esm/onnx-model.js ***!
\****************************************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//# sourceMappingURL=onnx-model.js.map
/***/ }),
/***/ "./node_modules/onnxruntime-common/dist/esm/onnx-value.js":
/*!****************************************************************!*\
!*** ./node_modules/onnxruntime-common/dist/esm/onnx-value.js ***!
\****************************************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//# sourceMappingURL=onnx-value.js.map
/***/ }),
/***/ "./node_modules/onnxruntime-common/dist/esm/tensor-conversion-impl.js":
/*!****************************************************************************!*\
!*** ./node_modules/onnxruntime-common/dist/esm/tensor-conversion-impl.js ***!
\****************************************************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ tensorToDataURL: () => (/* binding */ tensorToDataURL),
/* harmony export */ tensorToImageData: () => (/* binding */ tensorToImageData)
/* harmony export */ });
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
/**
 * Per-channel [R, G, B, A] parameters for turning tensor values back into 0-255 pixel values with
 * `(value - bias) * mean`.
 *
 * A numeric `mean`/`bias` applies to all four channels. For an array, a missing alpha entry falls back to 255
 * (`mean`) or 0 (`bias`). When `norm` or either field is absent, `mean` defaults to 255 and `bias` to 0.
 */
const resolveTensorToPixelNorm = (norm) => {
  const expand = (value, fallback) => {
    if (value === undefined) {
      return [fallback, fallback, fallback, fallback];
    }
    if (typeof value === 'number') {
      return [value, value, value, value];
    }
    return [value[0], value[1], value[2], value[3] !== undefined ? value[3] : fallback];
  };
  return { mean: expand(norm?.mean, 255), bias: expand(norm?.bias, 0) };
};
/**
 * For each supported image format: the tensor channel index holding R, G, B and A (-1 = no alpha channel).
 * Unknown formats are decoded as 'RGB'.
 */
const TENSOR_CHANNEL_ORDER = new Map([
  ['RGB', [0, 1, 2, -1]],
  ['RGBA', [0, 1, 2, 3]],
  ['RBG', [0, 2, 1, -1]],
  ['BGR', [2, 1, 0, -1]],
]);
/**
 * Describe where each pixel channel lives in `tensor.data`.
 *
 * NCHW (the default) stores every channel as a contiguous plane of `height * width` values, so channel `c` of
 * pixel `p` is at `c * height * width + p`. NHWC interleaves channels, so it is at `p * channels + c`.
 *
 * @returns `{ width, height, channels, offsets, pixelStep }`: `offsets` holds the start index of the R, G, B and
 *     A data (-1 when the format has no alpha) and `pixelStep` is the distance between consecutive pixels.
 */
const resolveTensorPixelLayout = (tensor, options) => {
  const isNHWC = options?.tensorLayout === 'NHWC';
  const [height, width, channels] = isNHWC
    ? [tensor.dims[1], tensor.dims[2], tensor.dims[3]]
    : [tensor.dims[2], tensor.dims[3], tensor.dims[1]];
  const format = options?.format ?? 'RGB';
  const order = TENSOR_CHANNEL_ORDER.get(format) ?? TENSOR_CHANNEL_ORDER.get('RGB');
  const planeSize = height * width;
  const offsets = order.map((channel) => {
    if (channel === -1) {
      return -1;
    }
    return isNHWC ? channel : channel * planeSize;
  });
  return { width, height, channels, offsets, pixelStep: isNHWC ? channels : 1 };
};
/**
 * De-normalized value of `channel` (0 = R, 1 = G, 2 = B, 3 = A) for pixel `pixelIndex` (row-major).
 * An absent alpha channel reads as fully opaque (255).
 */
const readTensorChannel = (data, layout, norm, pixelIndex, channel) => {
  const offset = layout.offsets[channel];
  if (offset === -1) {
    return 255;
  }
  return (data[offset + pixelIndex * layout.pixelStep] - norm.bias[channel]) * norm.mean[channel];
};
/**
 * implementation of Tensor.toDataURL()
 *
 * Paints every pixel onto a canvas sized to the tensor's image dimensions and returns `canvas.toDataURL()`.
 *
 * @param tensor - image tensor in NCHW (default) or NHWC layout (`options.tensorLayout`).
 * @param options - optional `format` ('RGB' default, 'RGBA', 'RBG', 'BGR'), `tensorLayout` and `norm`.
 * @throws Error when no 2D context is available or the canvas has no `toDataURL` (e.g. OffscreenCanvas).
 */
const tensorToDataURL = (tensor, options) => {
  const layout = resolveTensorPixelLayout(tensor, options);
  const canvas = typeof document !== 'undefined' ? document.createElement('canvas') : new OffscreenCanvas(1, 1);
  canvas.width = layout.width;
  canvas.height = layout.height;
  const pixels2DContext = canvas.getContext('2d');
  if (pixels2DContext == null) {
    throw new Error('Can not access image data');
  }
  const norm = resolveTensorToPixelNorm(options?.norm);
  for (let y = 0; y < layout.height; y++) {
    for (let x = 0; x < layout.width; x++) {
      const pixelIndex = y * layout.width + x;
      const r = readTensorChannel(tensor.data, layout, norm, pixelIndex, 0);
      const g = readTensorChannel(tensor.data, layout, norm, pixelIndex, 1);
      const b = readTensorChannel(tensor.data, layout, norm, pixelIndex, 2);
      const a = readTensorChannel(tensor.data, layout, norm, pixelIndex, 3);
      // CSS rgba() takes alpha in [0, 1] (larger values are clamped to 1), so rescale the 0-255 alpha.
      pixels2DContext.fillStyle = `rgba(${r},${g},${b},${a / 255})`;
      pixels2DContext.fillRect(x, y, 1, 1);
    }
  }
  if ('toDataURL' in canvas) {
    return canvas.toDataURL();
  }
  throw new Error('toDataURL is not supported');
};
/**
 * implementation of Tensor.toImageData()
 *
 * @param tensor - image tensor in NCHW (default) or NHWC layout (`options.tensorLayout`).
 * @param options - optional `format` ('RGB' default, 'RGBA', 'RBG', 'BGR'), `tensorLayout` and `norm`.
 * @returns an ImageData of the tensor's width and height (RGBA, alpha 255 when the format has none).
 * @throws Error when no 2D context is available, or when an explicit `format` does not fit the channel count
 *     (4 channels require 'RGBA'; 3 channels require 'RGB', 'BGR' or 'RBG').
 */
const tensorToImageData = (tensor, options) => {
  const pixels2DContext = typeof document !== 'undefined'
    ? document.createElement('canvas').getContext('2d')
    : new OffscreenCanvas(1, 1).getContext('2d');
  if (pixels2DContext == null) {
    throw new Error('Can not access image data');
  }
  const layout = resolveTensorPixelLayout(tensor, options);
  const norm = resolveTensorToPixelNorm(options?.norm);
  const format = options?.format;
  // Only an explicitly requested format is checked; without one the default 'RGB' decoding is used.
  if (format !== undefined) {
    const mismatch = (layout.channels === 4 && format !== 'RGBA') ||
      (layout.channels === 3 && !['RGB', 'BGR', 'RBG'].includes(format));
    if (mismatch) {
      throw new Error("Tensor format doesn't match input tensor dims");
    }
  }
  const image = pixels2DContext.createImageData(layout.width, layout.height);
  const pixelCount = layout.width * layout.height;
  for (let pixelIndex = 0; pixelIndex < pixelCount; pixelIndex++) {
    for (let channel = 0; channel < 4; channel++) {
      image.data[pixelIndex * 4 + channel] = readTensorChannel(tensor.data, layout, norm, pixelIndex, channel);
    }
  }
  return image;
};
//# sourceMappingURL=tensor-conversion-impl.js.map
/***/ }),
/***/ "./node_modules/onnxruntime-common/dist/esm/tensor-conversion.js":
/*!***********************************************************************!*\
!*** ./node_modules/onnxruntime-common/dist/esm/tensor-conversion.js ***!
\***********************************************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//# sourceMappingURL=tensor-conversion.js.map
/***/ }),
/***/ "./node_modules/onnxruntime-common/dist/esm/tensor-factory-impl.js":
/*!*************************************************************************!*\
!*** ./node_modules/onnxruntime-common/dist/esm/tensor-factory-impl.js ***!
\*************************************************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ bufferToTensor: () => (/* binding */ bufferToTensor),
/* harmony export */ tensorFromGpuBuffer: () => (/* binding */ tensorFromGpuBuffer),
/* harmony export */ tensorFromImage: () => (/* binding */ tensorFromImage),
/* harmony export */ tensorFromMLTensor: () => (/* binding */ tensorFromMLTensor),
/* harmony export */ tensorFromPinnedBuffer: () => (/* binding */ tensorFromPinnedBuffer),
/* harmony export */ tensorFromTexture: () => (/* binding */ tensorFromTexture)
/* harmony export */ });
/* harmony import */ var _tensor_impl_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ./tensor-impl.js */ "./node_modules/onnxruntime-common/dist/esm/tensor-impl.js");
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
/**
 * Create a new float32 tensor from a raw, interleaved pixel buffer.
 *
 * Each output value is `(pixel + bias) / mean` for its channel, written in planar NCHW order.
 *
 * @param buffer - pixel values, interleaved per pixel: 3 per pixel when `options.format` is 'RGB', otherwise 4
 *     (RGBA, the layout of ImageData / canvas reads).
 * @param options - conversion options:
 *   - `height`, `width` (required): image size in pixels.
 *   - `format`: layout of `buffer`, 'RGBA' (default) or 'RGB'.
 *   - `tensorFormat`: channel order of the output, 'RGB' (default), 'RGBA', 'RBG' or 'BGR'.
 *   - `tensorLayout`: only NCHW is supported; 'NHWC' throws.
 *   - `norm`: `{ mean, bias }`, each a number or per-channel array; defaults to `{ mean: 255, bias: 0 }`, and a
 *     missing field (or missing alpha entry) falls back to mean 255 / bias 0.
 * @returns a Tensor with dims [1, 3, height, width], or [1, 4, height, width] for 'RGBA' output. When the
 *     output is 'RGBA' but the input is 'RGB', the alpha plane is left at 0.
 * @throws Error when `buffer`, `height` or `width` is missing, or for the NHWC layout.
 */
const bufferToTensor = (buffer, options) => {
  if (buffer === undefined) {
    throw new Error('Image buffer must be defined');
  }
  if (options.height === undefined || options.width === undefined) {
    throw new Error('Image height and width must be defined');
  }
  if (options.tensorLayout === 'NHWC') {
    throw new Error('NHWC Tensor layout is not supported yet');
  }
  const { height, width } = options;
  const norm = options.norm ?? { mean: 255, bias: 0 };
  // Expand a scalar or per-channel array into [R, G, B, A]; a missing value or alpha entry uses `fallback`.
  const toChannels = (value, fallback) => {
    if (value === undefined) {
      return [fallback, fallback, fallback, fallback];
    }
    if (typeof value === 'number') {
      return [value, value, value, value];
    }
    return [value[0], value[1], value[2], value[3] ?? fallback];
  };
  const normMean = toChannels(norm.mean, 255);
  const normBias = toChannels(norm.bias, 0);
  // default input format is RGBA since ImageData and HTMLImageElement reads use it
  const inputformat = options.format ?? 'RGBA';
  const outputformat = options.tensorFormat ?? 'RGB';
  const pixelStep = inputformat === 'RGB' ? 3 : 4;
  const outputChannels = outputformat === 'RGBA' ? 4 : 3;
  const copyAlpha = outputChannels === 4 && pixelStep === 4;
  const stride = height * width;
  const float32Data = new Float32Array(stride * outputChannels);
  // Output plane index (each plane is `stride` values long) receiving R, G and B.
  let rPlane = 0;
  let gPlane = 1;
  let bPlane = 2;
  if (outputformat === 'RBG') {
    gPlane = 2;
    bPlane = 1;
  } else if (outputformat === 'BGR') {
    rPlane = 2;
    bPlane = 0;
  }
  for (let i = 0; i < stride; i++) {
    const src = i * pixelStep;
    float32Data[rPlane * stride + i] = (buffer[src] + normBias[0]) / normMean[0];
    float32Data[gPlane * stride + i] = (buffer[src + 1] + normBias[1]) / normMean[1];
    float32Data[bPlane * stride + i] = (buffer[src + 2] + normBias[2]) / normMean[2];
    if (copyAlpha) {
      float32Data[3 * stride + i] = (buffer[src + 3] + normBias[3]) / normMean[3];
    }
  }
  // Float32Array -> ort.Tensor
  return new _tensor_impl_js__WEBPACK_IMPORTED_MODULE_0__.Tensor('float32', float32Data, [1, outputChannels, height, width]);
};
/**
 * implementation of Tensor.fromImage().
 *
 * Reads RGBA pixels from an HTMLImageElement, ImageData, ImageBitmap or image URL string and converts them with
 * `bufferToTensor`. The derived `height`/`width` (plus `tensorFormat` or `format` where set below) are written to
 * a shallow copy of `options`, so the caller's object is never modified and can be reused across calls.
 *
 * @param image - the image source.
 * @param options - optional conversion options forwarded to `bufferToTensor`; required for ImageBitmap input.
 * @returns (a promise of) the created tensor.
 * @throws Error for unsupported inputs, missing canvas support, or (for URLs) a failed image load.
 */
const tensorFromImage = async (image, options) => {
  // checking the type of image object
  const isHTMLImageEle = typeof HTMLImageElement !== 'undefined' && image instanceof HTMLImageElement;
  const isImageDataEle = typeof ImageData !== 'undefined' && image instanceof ImageData;
  const isImageBitmap = typeof ImageBitmap !== 'undefined' && image instanceof ImageBitmap;
  const isString = typeof image === 'string';
  let data;
  // Copy rather than alias the caller's options: the fields assigned below previously leaked into the caller's
  // object, so reusing it (now carrying tensorFormat = 'RGBA') made the next HTMLImageElement call throw.
  const bufferToTensorOptions = { ...options };
  const createCanvas = () => {
    if (typeof document !== 'undefined') {
      return document.createElement('canvas');
    }
    if (typeof OffscreenCanvas !== 'undefined') {
      return new OffscreenCanvas(1, 1);
    }
    throw new Error('Canvas is not supported');
  };
  const createCanvasContext = (canvas) => {
    if (typeof HTMLCanvasElement !== 'undefined' && canvas instanceof HTMLCanvasElement) {
      return canvas.getContext('2d');
    }
    if (canvas instanceof OffscreenCanvas) {
      return canvas.getContext('2d');
    }
    return null;
  };
  // filling and checking image configuration options
  if (isHTMLImageEle) {
    // HTMLImageElement - image object - format is RGBA by default
    const canvas = createCanvas();
    canvas.width = image.width;
    canvas.height = image.height;
    const pixels2DContext = createCanvasContext(canvas);
    if (pixels2DContext == null) {
      throw new Error('Can not access image data');
    }
    let height = image.height;
    let width = image.width;
    if (options?.resizedHeight !== undefined && options.resizedWidth !== undefined) {
      height = options.resizedHeight;
      width = options.resizedWidth;
    }
    if (options?.tensorFormat !== undefined) {
      throw new Error('Image input config format must be RGBA for HTMLImageElement');
    }
    bufferToTensorOptions.tensorFormat = 'RGBA';
    bufferToTensorOptions.height = height;
    bufferToTensorOptions.width = width;
    // NOTE(review): the image is drawn at its natural size and a width x height window is read back, so the
    // "resized" dimensions crop rather than scale — confirm this is intended.
    pixels2DContext.drawImage(image, 0, 0);
    data = pixels2DContext.getImageData(0, 0, width, height).data;
  } else if (isImageDataEle) {
    let height = image.height;
    let width = image.width;
    if (options?.resizedWidth !== undefined && options.resizedHeight !== undefined) {
      height = options.resizedHeight;
      width = options.resizedWidth;
    }
    bufferToTensorOptions.format = 'RGBA';
    bufferToTensorOptions.height = height;
    bufferToTensorOptions.width = width;
    if (options !== undefined) {
      // NOTE(review): putImageData writes at native size and a width x height window is read back (a crop).
      const tempCanvas = createCanvas();
      tempCanvas.width = width;
      tempCanvas.height = height;
      const pixels2DContext = createCanvasContext(tempCanvas);
      if (pixels2DContext == null) {
        throw new Error('Can not access image data');
      }
      pixels2DContext.putImageData(image, 0, 0);
      data = pixels2DContext.getImageData(0, 0, width, height).data;
    } else {
      data = image.data;
    }
  } else if (isImageBitmap) {
    // ImageBitmap - image object - format must be provided by user
    if (options === undefined) {
      throw new Error('Please provide image config with format for Imagebitmap');
    }
    const canvas = createCanvas();
    canvas.width = image.width;
    canvas.height = image.height;
    const pixels2DContext = createCanvasContext(canvas);
    if (pixels2DContext == null) {
      throw new Error('Can not access image data');
    }
    const height = image.height;
    const width = image.width;
    pixels2DContext.drawImage(image, 0, 0, width, height);
    data = pixels2DContext.getImageData(0, 0, width, height).data;
    bufferToTensorOptions.height = height;
    bufferToTensorOptions.width = width;
    return bufferToTensor(data, bufferToTensorOptions);
  } else if (isString) {
    return new Promise((resolve, reject) => {
      const canvas = createCanvas();
      const context = createCanvasContext(canvas);
      if (!image) {
        reject(new Error('Input data provided is not supported - aborted tensor creation'));
        return;
      }
      if (!context) {
        reject(new Error('Can not access image data'));
        return;
      }
      const newImage = new Image();
      newImage.crossOrigin = 'Anonymous';
      newImage.onload = () => {
        // Errors thrown inside an event handler would otherwise escape the promise and leave it pending.
        try {
          canvas.width = newImage.width;
          canvas.height = newImage.height;
          context.drawImage(newImage, 0, 0, canvas.width, canvas.height);
          const img = context.getImageData(0, 0, canvas.width, canvas.height);
          bufferToTensorOptions.height = canvas.height;
          bufferToTensorOptions.width = canvas.width;
          resolve(bufferToTensor(img.data, bufferToTensorOptions));
        } catch (err) {
          reject(err);
        }
      };
      // Without this handler a failed load (404, CORS, undecodable data) left the promise pending forever.
      newImage.onerror = () => {
        reject(new Error('Failed to load image'));
      };
      // Assigned last so both handlers are in place before loading starts.
      newImage.src = image;
    });
  } else {
    throw new Error('Input data provided is not supported - aborted tensor creation');
  }
  if (data !== undefined) {
    return bufferToTensor(data, bufferToTensorOptions);
  }
  throw new Error('Input data provided is not supported - aborted tensor creation');
};
/**
 * implementation of Tensor.fromTexture().
 *
 * The texture is always described as RGBA float32, giving dims [1, height, width, 4].
 * TODO: support different texture format
 */
const tensorFromTexture = (texture, options) => {
  const { width: textureWidth, height: textureHeight, download, dispose } = options;
  const channelsPerTexel = 4; // RGBA
  return new _tensor_impl_js__WEBPACK_IMPORTED_MODULE_0__.Tensor({
    location: 'texture',
    type: 'float32',
    texture,
    dims: [1, textureHeight, textureWidth, channelsPerTexel],
    download,
    dispose,
  });
};
/**
 * implementation of Tensor.fromGpuBuffer().
 *
 * Wraps a GPU buffer in a 'gpu-buffer' tensor; the element type falls back to 'float32' when `dataType` is
 * null or undefined.
 */
const tensorFromGpuBuffer = (gpuBuffer, options) => {
  const { dataType, dims, download, dispose } = options;
  const type = dataType === undefined || dataType === null ? 'float32' : dataType;
  const init = { location: 'gpu-buffer', type, gpuBuffer, dims, download, dispose };
  return new _tensor_impl_js__WEBPACK_IMPORTED_MODULE_0__.Tensor(init);
};
/**
 * implementation of Tensor.fromMLTensor().
 *
 * Wraps an MLTensor in an 'ml-tensor' tensor; the element type falls back to 'float32' when `dataType` is
 * null or undefined.
 */
const tensorFromMLTensor = (mlTensor, options) => {
  const { dataType, dims, download, dispose } = options;
  const type = dataType === undefined || dataType === null ? 'float32' : dataType;
  const init = { location: 'ml-tensor', type, mlTensor, dims, download, dispose };
  return new _tensor_impl_js__WEBPACK_IMPORTED_MODULE_0__.Tensor(init);
};
/**
 * implementation of Tensor.fromPinnedBuffer().
 *
 * Without explicit `dims` (null or undefined) the buffer is treated as a 1-D tensor of `buffer.length` elements.
 */
const tensorFromPinnedBuffer = (type, buffer, dims) => {
  const tensorDims = dims === undefined || dims === null ? [buffer.length] : dims;
  return new _tensor_impl_js__WEBPACK_IMPORTED_MODULE_0__.Tensor({ location: 'cpu-pinned', type, data: buffer, dims: tensorDims });
};
//# sourceMappingURL=tensor-factory-impl.js.map
/***/ }),
/***/ "./node_modules/onnxruntime-common/dist/esm/tensor-factory.js":
/*!********************************************************************!*\
!*** ./node_modules/onnxruntime-common/dist/esm/tensor-factory.js ***!
\********************************************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//# sourceMappingURL=tensor-factory.js.map
/***/ }),
/***/ "./node_modules/onnxruntime-common/dist/esm/tensor-impl-type-mapping.js":
/*!******************************************************************************!*\
!*** ./node_modules/onnxruntime-common/dist/esm/tensor-impl-type-mapping.js ***!
\******************************************************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP: () => (/* binding */ NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP),
/* harmony export */ NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP: () => (/* binding */ NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP),
/* harmony export */ checkTypedArray: () => (/* binding */ checkTypedArray)
/* harmony export */ });
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// Runtime lookup from a tensor type string to the TypedArray constructor that stores its data.
// Should match Tensor.DataTypeMap. 'bool', 'int4' and 'uint4' have no dedicated TypedArray and are stored in a
// Uint8Array. The 'int64', 'uint64' and 'float16' entries are not listed here; checkTypedArray() adds them.
const NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP = new Map(Object.entries({
    float32: Float32Array,
    uint8: Uint8Array,
    int8: Int8Array,
    uint16: Uint16Array,
    int16: Int16Array,
    int32: Int32Array,
    bool: Uint8Array,
    float64: Float64Array,
    uint32: Uint32Array,
    int4: Uint8Array,
    uint4: Uint8Array,
}));
// a runtime map that maps TypedArray constructor to type string -- the reverse direction of
// NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP. Should match Tensor.DataTypeMap.
// The Tensor constructor uses it to infer a tensor's type from the data's constructor. Uint8Array resolves to
// 'uint8' only: 'bool', 'int4' and 'uint4' data also live in Uint8Array, so those types cannot be inferred.
// BigInt64Array, BigUint64Array and (when available) Float16Array entries are added lazily by checkTypedArray().
const NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP = new Map([
    [Float32Array, 'float32'],
    [Uint8Array, 'uint8'],
    [Int8Array, 'int8'],
    [Uint16Array, 'uint16'],
    [Int16Array, 'int16'],
    [Int32Array, 'int32'],
    [Float64Array, 'float64'],
    [Uint32Array, 'uint32'],
]);
// The int64 / uint64 / float16 entries of NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP and
// NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP are filled in on the first call to checkTypedArray(), not at module
// load. This lets a BigInt64Array / BigUint64Array / Float16Array polyfill that is installed after this module
// has been evaluated still be picked up.
let isTypedArrayChecked = false;
/**
 * Completes both type maps with the optional TypedArray types. Only the first call does any work.
 *
 * - 'int64' / 'uint64' are registered in both directions only if BigInt64Array / BigUint64Array exist.
 * - 'float16' is registered in both directions if Float16Array exists. Otherwise 'float16' is mapped to
 *   Uint16Array in the type -> constructor direction only (Uint16Array keeps resolving to 'uint16').
 */
const checkTypedArray = () => {
    if (isTypedArrayChecked) {
        return;
    }
    isTypedArrayChecked = true;
    // `typeof` guards are required: naming an undeclared global directly would throw a ReferenceError.
    const int64Ctor = typeof BigInt64Array !== 'undefined' && BigInt64Array.from ? BigInt64Array : undefined;
    const uint64Ctor = typeof BigUint64Array !== 'undefined' && BigUint64Array.from ? BigUint64Array : undefined;
    const float16Ctor = typeof Float16Array !== 'undefined' && Float16Array.from ? Float16Array : undefined;
    const bigIntTypes = [
        ['int64', int64Ctor],
        ['uint64', uint64Ctor],
    ];
    for (const [typeName, ctor] of bigIntTypes) {
        if (ctor) {
            NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set(typeName, ctor);
            NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(ctor, typeName);
        }
    }
    if (!float16Ctor) {
        // if Float16Array is not available, use 'Uint16Array' to store the data.
        NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Uint16Array);
        return;
    }
    NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', float16Ctor);
    NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(float16Ctor, 'float16');
};
//# sourceMappingURL=tensor-impl-type-mapping.js.map
/***/ }),
/***/ "./node_modules/onnxruntime-common/dist/esm/tensor-impl.js":
/*!*****************************************************************!*\
!*** ./node_modules/onnxruntime-common/dist/esm/tensor-impl.js ***!
\*****************************************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ Tensor: () => (/* binding */ Tensor)
/* harmony export */ });
/* harmony import */ var _tensor_conversion_impl_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ./tensor-conversion-impl.js */ "./node_modules/onnxruntime-common/dist/esm/tensor-conversion-impl.js");
/* harmony import */ var _tensor_factory_impl_js__WEBPACK_IMPORTED_MODULE_1__ = __webpack_require__(/*! ./tensor-factory-impl.js */ "./node_modules/onnxruntime-common/dist/esm/tensor-factory-impl.js");
/* harmony import */ var _tensor_impl_type_mapping_js__WEBPACK_IMPORTED_MODULE_2__ = __webpack_require__(/*! ./tensor-impl-type-mapping.js */ "./node_modules/onnxruntime-common/dist/esm/tensor-impl-type-mapping.js");
/* harmony import */ var _tensor_utils_impl_js__WEBPACK_IMPORTED_MODULE_3__ = __webpack_require__(/*! ./tensor-utils-impl.js */ "./node_modules/onnxruntime-common/dist/esm/tensor-utils-impl.js");
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
/**
 * the implementation of Tensor interface.
 *
 * @ignore
 */
class Tensor {
    /**
     * implementation.
     *
     * Supported call shapes (dispatched on `arg0`):
     * - `new Tensor({ location, type, dims, ... })`: wrap data held in a specific location
     *   ('cpu-pinned', 'texture', 'gpu-buffer' or 'ml-tensor').
     * - `new Tensor(type, data, dims?)`: create a 'cpu' tensor of an explicit type.
     * - `new Tensor(data, dims?)`: create a 'cpu' tensor whose type is inferred from `data`.
     *
     * For CPU tensors, omitting `dims` yields a 1-D shape `[data.length]`.
     *
     * @throws {TypeError} if the type, data or dims are not supported for the chosen call shape.
     * @throws {Error} if the location is unknown, or the CPU data length does not match the size of `dims`.
     */
    constructor(arg0, arg1, arg2) {
        // perform one-time check for BigInt/Float16Array support
        (0,_tensor_impl_type_mapping_js__WEBPACK_IMPORTED_MODULE_2__.checkTypedArray)();
        let type;
        let dims;
        // `typeof null === 'object'`; exclude null before the `in` check, which would throw on it.
        if (typeof arg0 === 'object' && arg0 !== null && 'location' in arg0) {
            //
            // constructing tensor from specific location
            //
            this.dataLocation = arg0.location;
            type = arg0.type;
            dims = arg0.dims;
            switch (arg0.location) {
                case 'cpu-pinned': {
                    const expectedTypedArrayConstructor = _tensor_impl_type_mapping_js__WEBPACK_IMPORTED_MODULE_2__.NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.get(type);
                    if (!expectedTypedArrayConstructor) {
                        throw new TypeError(`unsupported type "${type}" to create tensor from pinned buffer`);
                    }
                    if (!(arg0.data instanceof expectedTypedArrayConstructor)) {
                        throw new TypeError(`buffer should be of type ${expectedTypedArrayConstructor.name}`);
                    }
                    this.cpuData = arg0.data;
                    break;
                }
                case 'texture': {
                    if (type !== 'float32') {
                        throw new TypeError(`unsupported type "${type}" to create tensor from texture`);
                    }
                    this.gpuTextureData = arg0.texture;
                    this.downloader = arg0.download;
                    this.disposer = arg0.dispose;
                    break;
                }
                case 'gpu-buffer': {
                    if (type !== 'float32' &&
                        type !== 'float16' &&
                        type !== 'int32' &&
                        type !== 'int64' &&
                        type !== 'uint32' &&
                        type !== 'uint8' &&
                        type !== 'bool' &&
                        type !== 'uint4' &&
                        type !== 'int4') {
                        throw new TypeError(`unsupported type "${type}" to create tensor from gpu buffer`);
                    }
                    this.gpuBufferData = arg0.gpuBuffer;
                    this.downloader = arg0.download;
                    this.disposer = arg0.dispose;
                    break;
                }
                case 'ml-tensor': {
                    if (type !== 'float32' &&
                        type !== 'float16' &&
                        type !== 'int32' &&
                        type !== 'int64' &&
                        type !== 'uint32' &&
                        type !== 'uint64' &&
                        type !== 'int8' &&
                        type !== 'uint8' &&
                        type !== 'bool') {
                        throw new TypeError(`unsupported type "${type}" to create tensor from MLTensor`);
                    }
                    this.mlTensorData = arg0.mlTensor;
                    this.downloader = arg0.download;
                    this.disposer = arg0.dispose;
                    break;
                }
                default:
                    throw new Error(`Tensor constructor: unsupported location '${this.dataLocation}'`);
            }
        }
        else {
            //
            // constructing tensor of location 'cpu'
            //
            let data;
            let maybeDims;
            // check whether arg0 is type or data
            if (typeof arg0 === 'string') {
                //
                // Override: constructor(type, data, ...)
                //
                type = arg0;
                maybeDims = arg2;
                if (arg0 === 'string') {
                    // string tensor
                    if (!Array.isArray(arg1)) {
                        throw new TypeError("A string tensor's data must be a string array.");
                    }
                    // we don't check whether every element in the array is string; this is too slow. we assume it's
                    // correct and any error will surface at inference time
                    data = arg1;
                }
                else {
                    // numeric tensor
                    const typedArrayConstructor = _tensor_impl_type_mapping_js__WEBPACK_IMPORTED_MODULE_2__.NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.get(arg0);
                    if (typedArrayConstructor === undefined) {
                        throw new TypeError(`Unsupported tensor type: ${arg0}.`);
                    }
                    if (Array.isArray(arg1)) {
                        if ((arg0 === 'float16' && typedArrayConstructor === Uint16Array) || arg0 === 'uint4' || arg0 === 'int4') {
                            // - 'float16':
                            //   When no Float16Array polyfill is used, we cannot create 'float16' tensor from number array.
                            //
                            //   Throw error here because when user try to use number array as data,
                            //   e.g. new Tensor('float16', [1, 2, 3, 4], dims)), it will actually call
                            //   Uint16Array.from(arg1) which generates wrong data.
                            //
                            // - 'uint4' and 'int4':
                            //   Uint8Array.from(arg1) will generate wrong data for 'uint4' and 'int4' tensor.
                            //
                            throw new TypeError(`Creating a ${arg0} tensor from number array is not supported. Please use ${typedArrayConstructor.name} as data.`);
                        }
                        else if (arg0 === 'uint64' || arg0 === 'int64') {
                            // use 'as any' here because:
                            // 1. TypeScript's check on type of 'Array.isArray()' does not work with readonly arrays.
                            // see https://github.com/microsoft/TypeScript/issues/17002
                            // 2. TypeScript's check on union type of '(BigInt64ArrayConstructor|BigUint64ArrayConstructor).from()'
                            // does not accept parameter mapFn.
                            // 3. parameters of 'SupportedTypedArrayConstructors.from()' does not match the requirement of the union
                            // type.
                            // assume 'arg1' is of type "readonly number[]|readonly bigint[]" here.
                            // eslint-disable-next-line @typescript-eslint/no-explicit-any
                            data = typedArrayConstructor.from(arg1, BigInt);
                        }
                        else {
                            // assume 'arg1' is of type "readonly number[]" here.
                            // eslint-disable-next-line @typescript-eslint/no-explicit-any
                            data = typedArrayConstructor.from(arg1);
                        }
                    }
                    else if (arg1 instanceof typedArrayConstructor) {
                        data = arg1;
                    }
                    else if (arg1 instanceof Uint8ClampedArray) {
                        if (arg0 === 'uint8') {
                            data = Uint8Array.from(arg1);
                        }
                        else {
                            throw new TypeError(`A Uint8ClampedArray tensor's data must be type of uint8`);
                        }
                    }
                    else {
                        // Use the constructor's name: interpolating the constructor itself renders its source text.
                        throw new TypeError(`A ${type} tensor's data must be type of ${typedArrayConstructor.name}`);
                    }
                }
            }
            else {
                //
                // Override: constructor(data, ...)
                //
                maybeDims = arg1;
                if (Array.isArray(arg0)) {
                    // only boolean[] and string[] is supported
                    if (arg0.length === 0) {
                        throw new TypeError('Tensor type cannot be inferred from an empty array.');
                    }
                    const firstElementType = typeof arg0[0];
                    if (firstElementType === 'string') {
                        type = 'string';
                        data = arg0;
                    }
                    else if (firstElementType === 'boolean') {
                        type = 'bool';
                        // 'arg0' is of type 'boolean[]'. Uint8Array.from(boolean[]) actually works, but typescript thinks this is
                        // wrong type. We use 'as any' to make it happy.
                        // eslint-disable-next-line @typescript-eslint/no-explicit-any
                        data = Uint8Array.from(arg0);
                    }
                    else {
                        throw new TypeError(`Invalid element type of data array: ${firstElementType}.`);
                    }
                }
                else if (arg0 instanceof Uint8ClampedArray) {
                    type = 'uint8';
                    data = Uint8Array.from(arg0);
                }
                else {
                    // get tensor type from TypedArray.
                    // `arg0` may be null/undefined or a primitive here; optional chaining keeps those inputs on the
                    // TypeError below instead of failing with an unrelated property-access error.
                    const dataConstructor = arg0?.constructor;
                    const mappedType = _tensor_impl_type_mapping_js__WEBPACK_IMPORTED_MODULE_2__.NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.get(dataConstructor);
                    if (mappedType === undefined) {
                        const dataKind = dataConstructor?.name ?? (arg0 === null ? 'null' : typeof arg0);
                        throw new TypeError(`Unsupported type for tensor data: ${dataKind}.`);
                    }
                    type = mappedType;
                    data = arg0;
                }
            }
            // type and data is processed, now processing dims
            if (maybeDims === undefined) {
                // assume 1-D tensor if dims omitted
                maybeDims = [data.length];
            }
            else if (!Array.isArray(maybeDims)) {
                throw new TypeError("A tensor's dims must be a number array");
            }
            dims = maybeDims;
            this.cpuData = data;
            this.dataLocation = 'cpu';
        }
        // perform check on dims
        const size = (0,_tensor_utils_impl_js__WEBPACK_IMPORTED_MODULE_3__.calculateSize)(dims);
        // if data is on CPU, check whether data length matches tensor size
        if (this.cpuData && size !== this.cpuData.length) {
            // (u)int4 packs two elements per byte, so its buffer legitimately holds ceil(size / 2) bytes
            // (this also covers an odd element count).
            const isPackedInt4 = (type === 'uint4' || type === 'int4') && Math.ceil(size / 2) === this.cpuData.length;
            if (!isPackedInt4) {
                throw new Error(`Tensor's size(${size}) does not match data length(${this.cpuData.length}).`);
            }
        }
        this.type = type;
        this.dims = dims;
        this.size = size;
    }
    // #endregion
    // #region factory
    static async fromImage(image, options) {
        return (0,_tensor_factory_impl_js__WEBPACK_IMPORTED_MODULE_1__.tensorFromImage)(image, options);
    }
    static fromTexture(texture, options) {
        return (0,_tensor_factory_impl_js__WEBPACK_IMPORTED_MODULE_1__.tensorFromTexture)(texture, options);
    }
    static fromGpuBuffer(gpuBuffer, options) {
        return (0,_tensor_factory_impl_js__WEBPACK_IMPORTED_MODULE_1__.tensorFromGpuBuffer)(gpuBuffer, options);
    }
    static fromMLTensor(mlTensor, options) {
        return (0,_tensor_factory_impl_js__WEBPACK_IMPORTED_MODULE_1__.tensorFromMLTensor)(mlTensor, options);
    }
    static fromPinnedBuffer(type, buffer, dims) {
        return (0,_tensor_factory_impl_js__WEBPACK_IMPORTED_MODULE_1__.tensorFromPinnedBuffer)(type, buffer, dims);
    }
    // #endregion
    // #region conversions
    toDataURL(options) {
        return (0,_tensor_conversion_impl_js__WEBPACK_IMPORTED_MODULE_0__.tensorToDataURL)(this, options);
    }
    toImageData(options) {
        return (0,_tensor_conversion_impl_js__WEBPACK_IMPORTED_MODULE_0__.tensorToImageData)(this, options);
    }
    // #endregion
    // #region properties
    /** The CPU data of the tensor; throws if the data lives elsewhere or the tensor is disposed. */
    get data() {
        this.ensureValid();
        if (!this.cpuData) {
            throw new Error('The data is not on CPU. Use `getData()` to download GPU data to CPU, ' +
                'or use `texture`, `gpuBuffer` or `mlTensor` property to access the GPU data directly.');
        }
        return this.cpuData;
    }
    get location() {
        return this.dataLocation;
    }
    get texture() {
        this.ensureValid();
        if (!this.gpuTextureData) {
            throw new Error('The data is not stored as a WebGL texture.');
        }
        return this.gpuTextureData;
    }
    get gpuBuffer() {
        this.ensureValid();
        if (!this.gpuBufferData) {
            throw new Error('The data is not stored as a WebGPU buffer.');
        }
        return this.gpuBufferData;
    }
    get mlTensor() {
        this.ensureValid();
        if (!this.mlTensorData) {
            throw new Error('The data is not stored as a WebNN MLTensor.');
        }
        return this.mlTensorData;
    }
    // #endregion
    // #region methods
    /**
     * Returns the tensor data on CPU. For texture / gpu-buffer / ml-tensor locations the data is fetched
     * with the downloader, after which the tensor's location becomes 'cpu'.
     *
     * @param releaseData when true, the disposer (if any) is invoked after a successful download.
     * @throws {Error} if there is no downloader, a download is already in progress, or the location is unknown.
     */
    async getData(releaseData) {
        this.ensureValid();
        switch (this.dataLocation) {
            case 'cpu':
            case 'cpu-pinned':
                return this.data;
            case 'texture':
            case 'gpu-buffer':
            case 'ml-tensor': {
                if (!this.downloader) {
                    throw new Error('The current tensor is not created with a specified data downloader.');
                }
                if (this.isDownloading) {
                    throw new Error('The current tensor is being downloaded.');
                }
                try {
                    this.isDownloading = true;
                    const data = await this.downloader();
                    this.downloader = undefined;
                    this.dataLocation = 'cpu';
                    this.cpuData = data;
                    if (releaseData && this.disposer) {
                        this.disposer();
                        this.disposer = undefined;
                    }
                    return data;
                }
                finally {
                    this.isDownloading = false;
                }
            }
            default:
                throw new Error(`cannot get data from location: ${this.dataLocation}`);
        }
    }
    /**
     * Invokes the disposer (if any), drops every data reference and marks the tensor as disposed ('none').
     *
     * @throws {Error} if a download is in progress.
     */
    dispose() {
        if (this.isDownloading) {
            throw new Error('The current tensor is being downloaded.');
        }
        if (this.disposer) {
            this.disposer();
            this.disposer = undefined;
        }
        this.cpuData = undefined;
        this.gpuTextureData = undefined;
        this.gpuBufferData = undefined;
        this.mlTensorData = undefined;
        this.downloader = undefined;
        this.isDownloading = undefined;
        this.dataLocation = 'none';
    }
    // #endregion
    // #region tensor utilities
    ensureValid() {
        if (this.dataLocation === 'none') {
            throw new Error('The tensor is disposed.');
        }
    }
    /**
     * Returns a new tensor sharing this tensor's storage with different dims. Tensors that still hold a
     * downloader or disposer are rejected.
     */
    reshape(dims) {
        this.ensureValid();
        if (this.downloader || this.disposer) {
            throw new Error('Cannot reshape a tensor that owns GPU resource.');
        }
        return (0,_tensor_utils_impl_js__WEBPACK_IMPORTED_MODULE_3__.tensorReshape)(this, dims);
    }
}
//# sourceMappingURL=tensor-impl.js.map
/***/ }),
/***/ "./node_modules/onnxruntime-common/dist/esm/tensor-utils-impl.js":
/*!***********************************************************************!*\
!*** ./node_modules/onnxruntime-common/dist/esm/tensor-utils-impl.js ***!
\***********************************************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ calculateSize: () => (/* binding */ calculateSize),
/* harmony export */ tensorReshape: () => (/* binding */ tensorReshape)
/* harmony export */ });
/* harmony import */ var _tensor_impl_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ./tensor-impl.js */ "./node_modules/onnxruntime-common/dist/esm/tensor-impl.js");
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
/**
 * calculate size from dims.
 *
 * Returns the product of all dimensions (1 for an empty dims list). Validation stops at the first bad entry.
 *
 * @param dims the dims array. May be an illegal input.
 * @throws {TypeError} if an entry is not a safe integer number.
 * @throws {RangeError} if an entry is negative.
 */
const calculateSize = (dims) => {
    let total = 1;
    for (let axis = 0; axis < dims.length; axis += 1) {
        const extent = dims[axis];
        const isIntegerExtent = typeof extent === 'number' && Number.isSafeInteger(extent);
        if (!isIntegerExtent) {
            throw new TypeError(`dims[${axis}] must be an integer, got: ${extent}`);
        }
        if (extent < 0) {
            throw new RangeError(`dims[${axis}] must be a non-negative integer, got: ${extent}`);
        }
        total *= extent;
    }
    return total;
};
/**
 * For each non-CPU tensor location, the Tensor constructor option that carries the storage handle.
 * A Map is used so that arbitrary location strings (e.g. 'constructor') never hit inherited properties.
 */
const RESHAPE_STORAGE_FIELD_BY_LOCATION = new Map([
    ['cpu-pinned', 'data'],
    ['texture', 'texture'],
    ['gpu-buffer', 'gpuBuffer'],
    ['ml-tensor', 'mlTensor'],
]);
/**
 * implementation of Tensor.reshape()
 *
 * Builds a new tensor of the same type around the same underlying storage, with `dims` as its shape.
 * Download/dispose callbacks are not carried over; Tensor.reshape() refuses tensors that have them.
 *
 * @throws {Error} if the tensor's location is not supported.
 */
const tensorReshape = (tensor, dims) => {
    const TensorCtor = _tensor_impl_js__WEBPACK_IMPORTED_MODULE_0__.Tensor;
    const location = tensor.location;
    if (location === 'cpu') {
        return new TensorCtor(tensor.type, tensor.data, dims);
    }
    const storageField = RESHAPE_STORAGE_FIELD_BY_LOCATION.get(location);
    if (storageField === undefined) {
        throw new Error(`tensorReshape: tensor location ${location} is not supported`);
    }
    return new TensorCtor({
        location,
        [storageField]: tensor[storageField],
        type: tensor.type,
        dims,
    });
};
//# sourceMappingURL=tensor-utils-impl.js.map
/***/ }),
/***/ "./node_modules/onnxruntime-common/dist/esm/tensor.js":
/*!************************************************************!*\
!*** ./node_modules/onnxruntime-common/dist/esm/tensor.js ***!
\************************************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ Tensor: () => (/* binding */ Tensor)
/* harmony export */ });
/* harmony import */ var _tensor_impl_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ./tensor-impl.js */ "./node_modules/onnxruntime-common/dist/esm/tensor-impl.js");
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// Public re-export of the Tensor implementation class from tensor-impl.js.
// eslint-disable-next-line @typescript-eslint/naming-convention
const { Tensor } = _tensor_impl_js__WEBPACK_IMPORTED_MODULE_0__;
//# sourceMappingURL=tensor.js.map
/***/ }),
/***/ "./node_modules/onnxruntime-common/dist/esm/trace.js":
/*!***********************************************************!*\
!*** ./node_modules/onnxruntime-common/dist/esm/trace.js ***!
\***********************************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ TRACE: () => (/* binding */ TRACE),
/* harmony export */ TRACE_FUNC_BEGIN: () => (/* binding */ TRACE_FUNC_BEGIN),
/* harmony export */ TRACE_FUNC_END: () => (/* binding */ TRACE_FUNC_END)
/* harmony export */ });
/* harmony import */ var _env_impl_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ./env-impl.js */ "./node_modules/onnxruntime-common/dist/esm/env-impl.js");
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
/**
* @ignore
*/
const TRACE = (deviceType, label) => {
if (typeof _env_impl_js__WEBPACK_IMPORTED_MODULE_0__.env.trace === 'undefined' ? !_env_impl_js__WEBPACK_IMPORTED_MODULE_0__.env.wasm.trace : !_env_impl_js__WEBPACK_IMPORTED_MODULE_0__.env.trace) {
return;
}
// eslint-disable-next-line no-console
console.timeStamp(`${deviceType}::ORT::${label}`);
};
const TRACE_FUNC = (msg, extraMsg) => {
const stack = new Error().stack?.split(/\r\n|\r|\n/g) || [];
let hasTraceFunc = false;
for (let i = 0; i < stack.length; i++) {
if (hasTraceFunc && !stack[i].includes('TRACE_FUNC')) {
let label = `FUNC_${msg}::${stack[i].trim().split(' ')[1]}`;
if (extraMsg) {
label += `::${extraMsg}`;
}
TRACE('CPU', label);
return;
}
if (stack[i].includes('TRACE_FUNC')) {
hasTraceFunc = true;
}
}
};
/**
* @ignore
*/
const TRACE_FUNC_BEGIN = (extraMsg) => {
if (typeof _env_impl_js__WEBPACK_IMPORTED_MODULE_0__.env.trace === 'undefined' ? !_env_impl_js__WEBPACK_IMPORTED_MODULE_0__.env.wasm.trace : !_env_impl_js__WEBPACK_IMPORTED_MODULE_0__.env.trace) {
return;
}
TRACE_FUNC('BEGIN', extraMsg);
};
/**
* @ignore
*/
const TRACE_FUNC_END = (extraMsg) => {
if (typeof _env_impl_js__WEBPACK_IMPORTED_MODULE_0__.env.trace === 'undefined' ? !_env_impl_js__WEBPACK_IMPORTED_MODULE_0__.env.wasm.trace : !_env_impl_js__WEBPACK_IMPORTED_MODULE_0__.env.trace) {
return;
}
TRACE_FUNC('END', extraMsg);
};
//# sourceMappingURL=trace.js.map
/***/ }),
/***/ "./node_modules/onnxruntime-common/dist/esm/training-session-impl.js":
/*!***************************************************************************!*\
!*** ./node_modules/onnxruntime-common/dist/esm/training-session-impl.js ***!
\***************************************************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ TrainingSession: () => (/* binding */ TrainingSession)
/* harmony export */ });
/* harmony import */ var _backend_impl_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ./backend-impl.js */ "./node_modules/onnxruntime-common/dist/esm/backend-impl.js");
/* harmony import */ var _tensor_js__WEBPACK_IMPORTED_MODULE_1__ = __webpack_require__(/*! ./tensor.js */ "./node_modules/onnxruntime-common/dist/esm/tensor.js");
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// Message thrown by TrainingSession.create() when the resolved backend cannot create training sessions.
const noBackendErrMsg = "Training backend could not be resolved. Make sure you're using the correct configuration & WebAssembly files.";
/**
 * Training session front end: validates arguments and delegates every operation to a backend handler
 * produced by the resolved backend's `createTrainingSessionHandler()`.
 */
class TrainingSession {
    /**
     * @param handler backend training-session handler that all work is delegated to.
     * @param hasOptimizerModel whether an optimizer model was supplied; gates runOptimizerStep().
     * @param hasEvalModel whether an eval model was supplied; gates the eval* getters and runEvalStep().
     */
    constructor(handler, hasOptimizerModel, hasEvalModel) {
        this.handler = handler;
        this.hasOptimizerModel = hasOptimizerModel;
        this.hasEvalModel = hasEvalModel;
    }
    get trainingInputNames() {
        return this.handler.inputNames;
    }
    get trainingOutputNames() {
        return this.handler.outputNames;
    }
    /** @throws {Error} if the session was created without an eval model. */
    get evalInputNames() {
        if (!this.hasEvalModel) {
            throw new Error('This training session has no evalModel loaded.');
        }
        return this.handler.evalInputNames;
    }
    /** @throws {Error} if the session was created without an eval model. */
    get evalOutputNames() {
        if (!this.hasEvalModel) {
            throw new Error('This training session has no evalModel loaded.');
        }
        return this.handler.evalOutputNames;
    }
    /**
     * Resolves a backend from `sessionOptions` and asks it for a training-session handler.
     * Missing eval/optimizer models are forwarded to the backend as empty strings.
     *
     * @throws {Error} if the resolved backend has no `createTrainingSessionHandler`.
     */
    static async create(trainingOptions, sessionOptions) {
        const evalModelOrEmpty = trainingOptions.evalModel || '';
        const optimizerModelOrEmpty = trainingOptions.optimizerModel || '';
        const [backend, optionsWithValidatedEPs] = await (0,_backend_impl_js__WEBPACK_IMPORTED_MODULE_0__.resolveBackendAndExecutionProviders)(sessionOptions || {});
        if (!backend.createTrainingSessionHandler) {
            throw new Error(noBackendErrMsg);
        }
        const handler = await backend.createTrainingSessionHandler(trainingOptions.checkpointState, trainingOptions.trainModel, evalModelOrEmpty, optimizerModelOrEmpty, optionsWithValidatedEPs);
        return new TrainingSession(handler, Boolean(trainingOptions.optimizerModel), Boolean(trainingOptions.evalModel));
    }
    /**
     * Normalizes the `(feeds, fetches?, options?)` / `(feeds, options?)` call shapes shared by the run*Step
     * methods into a `[fetches, options]` pair for the handler.
     *
     * - `arg1` as a non-empty string array selects those outputs; `arg2` is then the optional options object.
     * - `arg1` as an object whose own keys include an output name mapped to `null` or a Tensor is treated as
     *   fetches (only such entries are kept); otherwise the whole object is the options.
     * - When no fetches are selected, every name in `outputNames` is fetched (mapped to `null`).
     *
     * @param inputNames every one of these must be present (not undefined) in `feeds`.
     * @param outputNames valid output names.
     * @param feeds the required input object.
     * @param arg1 fetches or options.
     * @param arg2 options, when `arg1` is fetches.
     * @returns `[fetches, options]`.
     * @throws {TypeError} for malformed arguments; {RangeError} for an unknown output name; {Error} for a
     *   missing input.
     */
    typeNarrowingForRunStep(inputNames, outputNames, feeds, arg1, arg2) {
        const TensorClass = _tensor_js__WEBPACK_IMPORTED_MODULE_1__.Tensor;
        const feedsIsRecord = typeof feeds === 'object' && feeds !== null && !(feeds instanceof TensorClass) && !Array.isArray(feeds);
        if (!feedsIsRecord) {
            throw new TypeError("'feeds' must be an object that use input names as keys and OnnxValue as corresponding values.");
        }
        // Validates the trailing options argument used by both fetches call shapes.
        const readOptionsArg = (candidate) => {
            if (typeof candidate === 'object' && candidate !== null) {
                return candidate;
            }
            if (typeof candidate !== 'undefined') {
                throw new TypeError("'options' must be an object.");
            }
            return {};
        };
        const fetches = {};
        let options = {};
        let fetchesGiven = false;
        // typeof null === 'object', so this is the first check the object branch would make anyway.
        if (arg1 === null) {
            throw new TypeError('Unexpected argument[1]: cannot be null.');
        }
        if (typeof arg1 === 'object') {
            if (arg1 instanceof TensorClass) {
                throw new TypeError("'fetches' cannot be a Tensor");
            }
            if (Array.isArray(arg1)) {
                if (arg1.length === 0) {
                    throw new TypeError("'fetches' cannot be an empty array.");
                }
                fetchesGiven = true;
                for (const outputName of arg1) {
                    if (typeof outputName !== 'string') {
                        throw new TypeError("'fetches' must be a string array or an object.");
                    }
                    if (!outputNames.includes(outputName)) {
                        throw new RangeError(`'fetches' contains invalid output name: ${outputName}.`);
                    }
                    fetches[outputName] = null;
                }
                options = readOptionsArg(arg2);
            }
            else {
                const ownKeys = Object.getOwnPropertyNames(arg1);
                for (const outputName of outputNames) {
                    if (!ownKeys.includes(outputName)) {
                        continue;
                    }
                    const value = arg1[outputName];
                    if (value === null || value instanceof TensorClass) {
                        fetchesGiven = true;
                        fetches[outputName] = value;
                    }
                }
                options = fetchesGiven ? readOptionsArg(arg2) : arg1;
            }
        }
        else if (typeof arg1 !== 'undefined') {
            throw new TypeError("Unexpected argument[1]: must be 'fetches' or 'options'.");
        }
        for (const inputName of inputNames) {
            if (typeof feeds[inputName] === 'undefined') {
                throw new Error(`input '${inputName}' is missing in 'feeds'.`);
            }
        }
        if (!fetchesGiven) {
            for (const outputName of outputNames) {
                fetches[outputName] = null;
            }
        }
        return [fetches, options];
    }
    /**
     * Converts handler results into a map of Tensors: Tensor values are kept as-is, anything else is rebuilt
     * from its `type`, `data` and `dims`. Only own enumerable keys of `results` are copied.
     */
    convertHandlerReturnTypeToMapOfTensors(results) {
        const TensorClass = _tensor_js__WEBPACK_IMPORTED_MODULE_1__.Tensor;
        const tensors = {};
        for (const key in results) {
            if (!Object.prototype.hasOwnProperty.call(results, key)) {
                continue;
            }
            const value = results[key];
            tensors[key] = value instanceof TensorClass ? value : new TensorClass(value.type, value.data, value.dims);
        }
        return tensors;
    }
    async lazyResetGrad() {
        await this.handler.lazyResetGrad();
    }
    async runTrainStep(feeds, arg1, arg2) {
        const [fetches, options] = this.typeNarrowingForRunStep(this.trainingInputNames, this.trainingOutputNames, feeds, arg1, arg2);
        const rawResults = await this.handler.runTrainStep(feeds, fetches, options);
        return this.convertHandlerReturnTypeToMapOfTensors(rawResults);
    }
    /** @throws {Error} if the session was created without an optimizer model. */
    async runOptimizerStep(options) {
        if (!this.hasOptimizerModel) {
            throw new Error('This TrainingSession has no OptimizerModel loaded.');
        }
        await this.handler.runOptimizerStep(options || {});
    }
    /** @throws {Error} if the session was created without an eval model. */
    async runEvalStep(feeds, arg1, arg2) {
        if (!this.hasEvalModel) {
            throw new Error('This TrainingSession has no EvalModel loaded.');
        }
        const [fetches, options] = this.typeNarrowingForRunStep(this.evalInputNames, this.evalOutputNames, feeds, arg1, arg2);
        const rawResults = await this.handler.runEvalStep(feeds, fetches, options);
        return this.convertHandlerReturnTypeToMapOfTensors(rawResults);
    }
    async getParametersSize(trainableOnly = true) {
        return this.handler.getParametersSize(trainableOnly);
    }
    /**
     * Loads parameters from a byte buffer whose length must equal 4 bytes per parameter (the float32 byte
     * length of the parameter count).
     *
     * @throws {Error} if the buffer length does not match.
     */
    async loadParametersBuffer(array, trainableOnly = true) {
        const expectedByteLength = 4 * (await this.getParametersSize(trainableOnly));
        if (array.length !== expectedByteLength) {
            throw new Error('Size of the buffer passed into loadParametersBuffer must match the number of parameters in ' +
                'the model. Please use getParametersSize method to check.');
        }
        return this.handler.loadParametersBuffer(array, trainableOnly);
    }
    async getContiguousParameters(trainableOnly = true) {
        return this.handler.getContiguousParameters(trainableOnly);
    }
    async release() {
        return this.handler.dispose();
    }
}
//# sourceMappingURL=training-session-impl.js.map
/***/ }),
/***/ "./node_modules/onnxruntime-common/dist/esm/training-session.js":
/*!**********************************************************************!*\
!*** ./node_modules/onnxruntime-common/dist/esm/training-session.js ***!
\**********************************************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ TrainingSession: () => (/* binding */ TrainingSession)
/* harmony export */ });
/* harmony import */ var _training_session_impl_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ./training-session-impl.js */ "./node_modules/onnxruntime-common/dist/esm/training-session-impl.js");
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// Public re-export of the TrainingSession implementation class from training-session-impl.js.
// eslint-disable-next-line @typescript-eslint/naming-convention
const { TrainingSession } = _training_session_impl_js__WEBPACK_IMPORTED_MODULE_0__;
//# sourceMappingURL=training-session.js.map
/***/ }),
/***/ "./node_modules/onnxruntime-common/dist/esm/version.js":
/*!*************************************************************!*\
!*** ./node_modules/onnxruntime-common/dist/esm/version.js ***!
\*************************************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ version: () => (/* binding */ version)
/* harmony export */ });
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This file is generated by /js/scripts/update-version.ts
// Do not modify file content manually.
// Version string of the bundled onnxruntime-common package, exported as `version`.
const version = '1.20.1';
//# sourceMappingURL=version.js.map
/***/ }),
/***/ "./src/backends/onnx.js":
/*!******************************!*\
!*** ./src/backends/onnx.js ***!
\******************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ Tensor: () => (/* reexport safe */ onnxruntime_common__WEBPACK_IMPORTED_MODULE_2__.Tensor),
/* harmony export */ createInferenceSession: () => (/* binding */ createInferenceSession),
/* harmony export */ deviceToExecutionProviders: () => (/* binding */ deviceToExecutionProviders),
/* harmony export */ isONNXProxy: () => (/* binding */ isONNXProxy),
/* harmony export */ isONNXTensor: () => (/* binding */ isONNXTensor)
/* harmony export */ });
/* harmony import */ var _env_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ../env.js */ "./src/env.js");
/* harmony import */ var _onnxruntime_webgpu__WEBPACK_IMPORTED_MODULE_1__ = __webpack_require__(/*! #onnxruntime-webgpu */ "#onnxruntime-webgpu");
/* harmony import */ var onnxruntime_common__WEBPACK_IMPORTED_MODULE_2__ = __webpack_require__(/*! onnxruntime-common */ "./node_modules/onnxruntime-common/dist/esm/index.js");
/**
* @file Handler file for choosing the correct version of ONNX Runtime, based on the environment.
* Ideally, we could import the `onnxruntime-web` and `onnxruntime-node` packages only when needed,
* but dynamic imports don't seem to work with the current webpack version and/or configuration.
* This is possibly due to the experimental nature of top-level await statements.
* So, we just import both packages, and use the appropriate one based on the environment:
* - When running in node, we use `onnxruntime-node`.
* - When running in the browser, we use `onnxruntime-web` (`onnxruntime-node` is not bundled).
*
* This module is not directly exported, but can be accessed through the environment variables:
* ```javascript
* import { env } from '@huggingface/transformers';
* console.log(env.backends.onnx);
* ```
*
* @module backends/onnx
*/
const ONNX_NODE = null;
/**
* @typedef {import('onnxruntime-common').InferenceSession.ExecutionProviderConfig} ONNXExecutionProviders
*/
/** @type {Record<import("../utils/devices.js").DeviceType, ONNXExecutionProviders>} */
const DEVICE_TO_EXECUTION_PROVIDER_MAPPING = Object.freeze({
auto: null, // Auto-detect based on device and environment
gpu: null, // Auto-detect GPU
cpu: 'cpu', // CPU
wasm: 'wasm', // WebAssembly
webgpu: 'webgpu', // WebGPU
cuda: 'cuda', // CUDA
dml: 'dml', // DirectML
webnn: { name: 'webnn', deviceType: 'cpu' }, // WebNN (default)
'webnn-npu': { name: 'webnn', deviceType: 'npu' }, // WebNN NPU
'webnn-gpu': { name: 'webnn', deviceType: 'gpu' }, // WebNN GPU
'webnn-cpu': { name: 'webnn', deviceType: 'cpu' }, // WebNN CPU
});
/**
* The list of supported devices, sorted by priority/performance.
* Populated below according to the detected JS environment; it stays empty when a
* runtime-provided ONNX Runtime (via `Symbol.for('onnxruntime')`) is used.
* @type {import("../utils/devices.js").DeviceType[]}
*/
const supportedDevices = [];
/**
 * Execution providers used when the caller requests no device.
 * Left `undefined` when a runtime-provided ONNX Runtime is used.
 * @type {ONNXExecutionProviders[]}
 */
let defaultDevices;
// The ONNX Runtime module selected below for this environment.
let ONNX;
// Global symbol under which a host JS runtime may expose its own ONNX Runtime.
const ORT_SYMBOL = Symbol.for('onnxruntime');
if (ORT_SYMBOL in globalThis) {
// If the JS runtime exposes their own ONNX runtime, use it
ONNX = globalThis[ORT_SYMBOL];
} else if (_env_js__WEBPACK_IMPORTED_MODULE_0__.apis.IS_NODE_ENV) {
// NOTE(review): `ONNX_NODE` is `null` in this bundle, so reading `.default` throws a
// TypeError if this branch is ever reached.
ONNX = ONNX_NODE.default ?? ONNX_NODE;
// Updated as of ONNX Runtime 1.18.0
// The following table lists the supported versions of ONNX Runtime Node.js binding provided with pre-built binaries.
// | EPs/Platforms | Windows x64 | Windows arm64 | Linux x64 | Linux arm64 | MacOS x64 | MacOS arm64 |
// | ------------- | ----------- | ------------- | ----------------- | ----------- | --------- | ----------- |
// | CPU | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
// | DirectML | ✔️ | ✔️ | ❌ | ❌ | ❌ | ❌ |
// | CUDA | ❌ | ❌ | ✔️ (CUDA v11.8) | ❌ | ❌ | ❌ |
// Push the platform-specific accelerator (if any) first, then always fall back to CPU.
switch (process.platform) {
case 'win32': // Windows x64 and Windows arm64
supportedDevices.push('dml');
break;
case 'linux': // Linux x64 and Linux arm64
if (process.arch === 'x64') {
supportedDevices.push('cuda');
}
break;
case 'darwin': // MacOS x64 and MacOS arm64
break;
}
supportedDevices.push('cpu');
defaultDevices = ['cpu'];
} else {
// Browser-like environment: use the bundled `onnxruntime-webgpu` module. Devices are pushed
// in priority order (WebNN variants, then WebGPU), with WASM as the final fallback.
ONNX = _onnxruntime_webgpu__WEBPACK_IMPORTED_MODULE_1__;
if (_env_js__WEBPACK_IMPORTED_MODULE_0__.apis.IS_WEBNN_AVAILABLE) {
// TODO: Only push supported providers (depending on available hardware)
supportedDevices.push('webnn-npu', 'webnn-gpu', 'webnn-cpu', 'webnn');
}
if (_env_js__WEBPACK_IMPORTED_MODULE_0__.apis.IS_WEBGPU_AVAILABLE) {
supportedDevices.push('webgpu');
}
supportedDevices.push('wasm');
defaultDevices = ['wasm'];
}
// Cached reference to the selected runtime's `InferenceSession` export.
// @ts-ignore
const InferenceSession = ONNX.InferenceSession;
/**
 * Map a device to the execution providers to use for the given device.
 *
 * Every device name in the result is resolved through `DEVICE_TO_EXECUTION_PROVIDER_MAPPING`.
 * This applies both to an explicitly requested device and to the `auto` / `gpu` expansions.
 * As a result, aliases such as `webnn-gpu` become the `{ name: 'webnn', deviceType: 'gpu' }`
 * config ONNX Runtime understands, rather than being passed through as unknown names.
 *
 * @param {import("../utils/devices.js").DeviceType|"auto"|null} [device=null] (Optional) The device to run the inference on.
 * @returns {ONNXExecutionProviders[]} The execution providers to use for the given device.
 * @throws {Error} If `device` is not one of the devices supported in this environment.
 */
function deviceToExecutionProviders(device = null) {
    // Resolve a device name to its provider value, falling back to the name itself.
    const toProvider = (name) => DEVICE_TO_EXECUTION_PROVIDER_MAPPING[name] ?? name;

    // Use the default execution providers if the user hasn't specified anything
    if (!device) return defaultDevices;

    // Handle overloaded cases: expand into concrete devices, then map each one.
    switch (device) {
        case "auto":
            return supportedDevices.map(toProvider);
        case "gpu":
            return supportedDevices
                .filter(x => ["webgpu", "cuda", "dml", "webnn-gpu"].includes(x))
                .map(toProvider);
    }

    if (supportedDevices.includes(device)) {
        return [toProvider(device)];
    }

    throw new Error(`Unsupported device: "${device}". Should be one of: ${supportedDevices.join(', ')}.`)
}
/**
 * To prevent multiple concurrent calls to `initWasm()`, the first call to
 * `createInferenceSession` stores its session-creation Promise here. Later calls wait for
 * it to settle before creating their own InferenceSession.
 *
 * If that first attempt fails, the next caller clears the slot and takes over
 * initialisation. This avoids every later call re-throwing the original error forever.
 * @type {Promise<any>|null}
 */
let wasmInitPromise = null;
/**
 * Create an ONNX inference session.
 * @param {Uint8Array} buffer The ONNX model buffer.
 * @param {import('onnxruntime-common').InferenceSession.SessionOptions} session_options ONNX inference session options.
 * @param {Object} session_config ONNX inference session configuration (attached as `session.config`).
 * @returns {Promise<import('onnxruntime-common').InferenceSession & { config: Object}>} The ONNX inference session.
 * @throws Whatever `InferenceSession.create` rejects with for *this* call. A failure of an
 * earlier call is not propagated here; it was already delivered to that call's own caller.
 */
async function createInferenceSession(buffer, session_options, session_config) {
    // Wait for the call that currently owns runtime initialisation. If it failed, either
    // take ownership ourselves or wait for whichever caller took it over first.
    while (wasmInitPromise) {
        const pending = wasmInitPromise;
        try {
            await pending;
            break; // Runtime initialised successfully; sessions can now be created freely.
        } catch {
            // Only clear the slot if nobody has replaced the failed promise yet.
            if (wasmInitPromise === pending) {
                wasmInitPromise = null;
            }
        }
    }
    const sessionPromise = InferenceSession.create(buffer, session_options);
    wasmInitPromise ??= sessionPromise;
    const session = await sessionPromise;
    session.config = session_config;
    return session;
}
/**
 * Determine whether a value is a tensor created by the active ONNX Runtime module.
 * @param {any} x The value to inspect.
 * @returns {boolean} `true` when `x` is an instance of the selected runtime's `Tensor`.
 */
function isONNXTensor(x) {
    const { Tensor: RuntimeTensor } = ONNX;
    return x instanceof RuntimeTensor;
}
// Apply library-wide defaults to the selected runtime's environment flags. `ONNX_ENV` is
// `undefined` when the selected ONNX module has no `env`; in that case nothing is changed.
/** @type {import('onnxruntime-common').Env} */
// @ts-ignore
const ONNX_ENV = ONNX?.env;
if (ONNX_ENV?.wasm) {
// Initialize wasm backend with suitable default settings.
// (Optional) Set path to wasm files. This is needed when running in a web worker.
// https://onnxruntime.ai/docs/api/js/interfaces/Env.WebAssemblyFlags.html#wasmPaths
// We use remote wasm files by default to make it easier for new users.
// In practice, users should probably self-host the necessary .wasm files.
// NOTE(review): this default points the runtime at a third-party CDN (pinned to this
// library's `env.version`); confirm it is overridden wherever fetching remote code is undesirable.
ONNX_ENV.wasm.wasmPaths = `https://cdn.jsdelivr.net/npm/@huggingface/transformers@${_env_js__WEBPACK_IMPORTED_MODULE_0__.env.version}/dist/`;
// TODO: Add support for loading WASM files from cached buffer when we upgrade to onnxruntime-web@1.19.0
// https://github.com/microsoft/onnxruntime/pull/21534
// Users may wish to proxy the WASM backend to prevent the UI from freezing;
// however, this is not necessary when using WebGPU, so we default to false.
ONNX_ENV.wasm.proxy = false;
// Without cross-origin isolation (or where `crossOriginIsolated` is not defined), force
// single-threaded WASM. Presumably this is because multi-threading needs isolation; verify.
// https://developer.mozilla.org/en-US/docs/Web/API/crossOriginIsolated
if (typeof crossOriginIsolated === 'undefined' || !crossOriginIsolated) {
ONNX_ENV.wasm.numThreads = 1;
}
}
// When WebGPU flags exist, request the 'high-performance' power preference.
if (ONNX_ENV?.webgpu) {
ONNX_ENV.webgpu.powerPreference = 'high-performance';
}
/**
 * Check if ONNX's WASM backend is being proxied.
 *
 * Reads `ONNX_ENV.wasm.proxy`. When the ONNX environment or its WASM flags are absent, this
 * returns `false` rather than `undefined`, so the result always matches the documented boolean type.
 * @returns {boolean} Whether ONNX's WASM backend is being proxied.
 */
function isONNXProxy() {
    // TODO: Update this when allowing non-WASM backends.
    return Boolean(ONNX_ENV?.wasm?.proxy);
}
// Expose the ONNX Runtime environment flags as `env.backends.onnx`. This is the same object
// configured above (not a copy), so settings such as `env.backends.onnx.wasm.proxy` can be
// changed through the library's `env`. It may be `undefined` if the selected ONNX module has no `env`.
_env_js__WEBPACK_IMPORTED_MODULE_0__.env.backends.onnx = ONNX_ENV;
/***/ }),
/***/ "./src/base/feature_extraction_utils.js":
/*!**********************************************!*\
!*** ./src/base/feature_extraction_utils.js ***!
\**********************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ FeatureExtractor: () => (/* binding */ FeatureExtractor),
/* harmony export */ validate_audio_inputs: () => (/* binding */ validate_audio_inputs)
/* harmony export */ });
/* harmony import */ var _utils_constants_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ../utils/constants.js */ "./src/utils/constants.js");
/* harmony import */ var _utils_generic_js__WEBPACK_IMPORTED_MODULE_1__ = __webpack_require__(/*! ../utils/generic.js */ "./src/utils/generic.js");
/* harmony import */ var _utils_hub_js__WEBPACK_IMPORTED_MODULE_2__ = __webpack_require__(/*! ../utils/hub.js */ "./src/utils/hub.js");
/**
 * Base class for feature extractors.
 *
 * This base class stores the configuration it was built from and provides the
 * `from_pretrained` factory that loads that configuration from a model.
 */
class FeatureExtractor extends _utils_generic_js__WEBPACK_IMPORTED_MODULE_1__.Callable {
    /**
     * @param {Object} config The configuration for the feature extractor (stored as `this.config`).
     */
    constructor(config) {
        super();
        this.config = config;
    }

    /**
     * Load the feature-extractor configuration of a pretrained model and build an instance.
     *
     * The JSON config file named by `FEATURE_EXTRACTOR_NAME` is fetched with `getModelJSON`.
     * The instance is created with `new this(config)`, so calling this on a subclass yields that subclass.
     *
     * @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either:
     * - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
     *   Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
     *   user or organization name, like `dbmdz/bert-base-german-cased`.
     * - A path to a *directory* containing the config file, e.g., `./my_model_directory/`.
     * @param {import('../utils/hub.js').PretrainedOptions} options Additional options for loading the config.
     * @returns {Promise<FeatureExtractor>} A new instance of the class this method was called on.
     */
    static async from_pretrained(pretrained_model_name_or_path, options) {
        const { getModelJSON } = _utils_hub_js__WEBPACK_IMPORTED_MODULE_2__;
        const configFileName = _utils_constants_js__WEBPACK_IMPORTED_MODULE_0__.FEATURE_EXTRACTOR_NAME;
        const loadedConfig = await getModelJSON(pretrained_model_name_or_path, configFileName, true, options);
        return new this(loadedConfig);
    }
}
/**
 * Ensure that audio passed to a feature extractor is raw floating-point sample data.
 * @param {any} audio The audio data; must be a `Float32Array` or `Float64Array`.
 * @param {string} feature_extractor The name of the feature extractor (used in the error message).
 * @throws {Error} If `audio` is not a `Float32Array` or `Float64Array`.
 * @private
 */
function validate_audio_inputs(audio, feature_extractor) {
    const isFloatAudio = audio instanceof Float32Array || audio instanceof Float64Array;
    if (isFloatAudio) {
        return;
    }
    const receivedType = audio?.constructor?.name ?? typeof audio;
    throw new Error(
        `${feature_extractor} expects input to be a Float32Array or a Float64Array, but got ${receivedType} instead. ` +
        `If using the feature extractor directly, remember to use \`read_audio(url, sampling_rate)\` to obtain the raw audio data of the file/url.`
    );
}
/***/ }),
/***/ "./src/base/image_processors_utils.js":
/*!********************************************!*\
!*** ./src/base/image_processors_utils.js ***!
\********************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ ImageProcessor: () => (/* binding */ ImageProcessor),
/* harmony export */ post_process_instance_segmentation: () => (/* binding */ post_process_instance_segmentation),
/* harmony export */ post_process_object_detection: () => (/* binding */ post_process_object_detection),
/* harmony export */ post_process_panoptic_segmentation: () => (/* binding */ post_process_panoptic_segmentation),
/* harmony export */ post_process_semantic_segmentation: () => (/* binding */ post_process_semantic_segmentation)
/* harmony export */ });
/* harmony import */ var _utils_generic_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ../utils/generic.js */ "./src/utils/generic.js");
/* harmony import */ var _utils_tensor_js__WEBPACK_IMPORTED_MODULE_1__ = __webpack_require__(/*! ../utils/tensor.js */ "./src/utils/tensor.js");
/* harmony import */ var _utils_maths_js__WEBPACK_IMPORTED_MODULE_2__ = __webpack_require__(/*! ../utils/maths.js */ "./src/utils/maths.js");
/* harmony import */ var _utils_image_js__WEBPACK_IMPORTED_MODULE_3__ = __webpack_require__(/*! ../utils/image.js */ "./src/utils/image.js");
/* harmony import */ var _utils_core_js__WEBPACK_IMPORTED_MODULE_4__ = __webpack_require__(/*! ../utils/core.js */ "./src/utils/core.js");
/* harmony import */ var _utils_hub_js__WEBPACK_IMPORTED_MODULE_5__ = __webpack_require__(/*! ../utils/hub.js */ "./src/utils/hub.js");
/* harmony import */ var _utils_constants_js__WEBPACK_IMPORTED_MODULE_6__ = __webpack_require__(/*! ../utils/constants.js */ "./src/utils/constants.js");
/**
* Named tuple indicating that the order we use is (height x width),
* even though the graphics industry standard is (width x height).
* @typedef {[height: number, width: number]} HeightWidth
*/
/**
* @typedef {object} ImageProcessorResult
* @property {Tensor} pixel_values The pixel values of the batched preprocessed images.
* @property {HeightWidth[]} original_sizes Array of two-dimensional tuples like [[480, 640]].
* @property {HeightWidth[]} reshaped_input_sizes Array of two-dimensional tuples like [[1000, 1330]].
*/
/**
 * Snap a value to a multiple of `multiple`, keeping it within optional bounds.
 *
 * The value is first rounded (banker's rounding) to the nearest multiple. If that exceeds
 * `maxVal`, the next lower multiple is used instead. If the result is then below `minVal`,
 * the next higher multiple is used.
 * @param {number} val The value to constrain.
 * @param {number} multiple The number to constrain to.
 * @param {number} [minVal=0] The minimum value to constrain to.
 * @param {number} [maxVal=null] The maximum value to constrain to (`null` = no upper bound).
 * @returns {number} The constrained value.
 * @private
 */
function constraint_to_multiple_of(val, multiple, minVal = 0, maxVal = null) {
    const { bankers_round } = _utils_maths_js__WEBPACK_IMPORTED_MODULE_2__;
    const quotient = val / multiple;
    const nearest = bankers_round(quotient) * multiple;
    const capped = (maxVal !== null && nearest > maxVal)
        ? Math.floor(quotient) * multiple
        : nearest;
    return capped < minVal ? Math.ceil(quotient) * multiple : capped;
}
/**
 * Round a `[width, height]` size down to the closest multiple of `divisor` in each dimension,
 * never going below one `divisor`.
 * @param {[number, number]} size The size of the image as `[width, height]`.
 * @param {number} divisor The divisor to use.
 * @returns {[number, number]} The rounded `[width, height]`.
 */
function enforce_size_divisibility([width, height], divisor) {
    const snapDown = (dim) => divisor * Math.max(1, Math.floor(dim / divisor));
    return [snapDown(width), snapDown(height)];
}
// Helper functions
/**
 * Converts a bounding box from center format to corners format.
 *
 * @param {number[]} arr The box as (center_x, center_y, width, height).
 * @returns {number[]} The coordinates of the top-left and bottom-right corners
 *   (top_left_x, top_left_y, bottom_right_x, bottom_right_y).
 */
function center_to_corners_format([centerX, centerY, width, height]) {
    const halfWidth = width / 2;
    const halfHeight = height / 2;
    return [centerX - halfWidth, centerY - halfHeight, centerX + halfWidth, centerY + halfHeight];
}
/**
 * Post-processes the outputs of an object-detection model into per-image detections.
 *
 * Standard mode: each box keeps only its arg-max class. A box is skipped when that class is
 * the final class (treated as background), or when its softmax probability is below `threshold`.
 *
 * Zero-shot mode (`is_zero_shot`): every class whose sigmoid probability is strictly greater
 * than `threshold` produces its own detection for that box.
 *
 * Boxes are converted from (center_x, center_y, width, height) to corner format. When a target
 * size is given, x-coordinates are scaled by its width and y-coordinates by its height.
 *
 * @param {Object} outputs The outputs of the model that must be post-processed
 * @param {Tensor} outputs.logits The logits
 * @param {Tensor} outputs.pred_boxes The predicted boxes.
 * @param {number} [threshold=0.5] The threshold to use for the scores.
 * @param {[number, number][]} [target_sizes=null] The (height, width) sizes of the original images.
 * @param {boolean} [is_zero_shot=false] Whether zero-shot object detection was performed.
 * @return {Object[]} One `{ boxes, classes, scores }` object per batch item.
 * @throws {Error} If `target_sizes` is given but its length differs from the batch size.
 */
function post_process_object_detection(outputs, threshold = 0.5, target_sizes = null, is_zero_shot = false) {
    const { logits: out_logits, pred_boxes: out_bbox } = outputs;
    const [batch_size, num_boxes, num_classes] = out_logits.dims;
    if (target_sizes !== null && target_sizes.length !== batch_size) {
        throw Error("Make sure that you pass in as many target sizes as the batch dimension of the logits")
    }
    const { max, softmax } = _utils_maths_js__WEBPACK_IMPORTED_MODULE_2__;
    const background_class = num_classes - 1;

    const results = [];
    for (let b = 0; b < batch_size; ++b) {
        const target_size = target_sizes === null ? null : target_sizes[b];
        const detections = { boxes: [], classes: [], scores: [] };
        const image_logits = out_logits[b];
        const image_boxes = out_bbox[b];

        for (let j = 0; j < num_boxes; ++j) {
            const logit = image_logits[j];
            let probs;
            /** @type {number[]} */
            let labels;
            if (is_zero_shot) {
                // Independent (sigmoid) probabilities: keep every class above the threshold.
                probs = logit.sigmoid().data;
                labels = [];
                probs.forEach((p, k) => {
                    if (p > threshold) labels.push(k);
                });
            } else {
                // Single most probable class, unless it is the background class.
                const best = max(logit.data)[1];
                if (best === background_class) continue;
                probs = softmax(logit.data);
                if (probs[best] < threshold) continue;
                labels = [best];
            }

            for (const label of labels) {
                let box = center_to_corners_format(image_boxes[j].data);
                if (target_size !== null) {
                    // Even positions are x-coordinates (scale by width), odd ones y (scale by height).
                    box = box.map((coord, c) => coord * target_size[(c + 1) % 2]);
                }
                detections.boxes.push(box);
                detections.classes.push(label);
                detections.scores.push(probs[label]);
            }
        }
        results.push(detections);
    }
    return results;
}
/**
 * Post-processes the outputs of the model (for semantic segmentation).
 *
 * For each batch item, the per-class score maps are optionally resized (bilinear) to the
 * requested size. Each pixel is then labelled with the index of the class with the largest
 * score. The caller's `outputs.logits` are not modified: the running per-pixel maximum is
 * tracked in a private copy of the first class's scores.
 *
 * @param {*} outputs Raw outputs of the model. `outputs.logits` is indexed as
 *   `logits[batch][class]`, and its last two dims are treated as (height, width).
 * @param {[number, number][]} [target_sizes=null] List of tuples corresponding to the requested final size
 * (height, width) of each prediction. If unset, predictions will not be resized.
 * @returns {{segmentation: Tensor; labels: number[]}[]} The semantic segmentation maps and the
 *   sorted list of class indices present in each map.
 * @throws {Error} If `target_sizes` is given but its length differs from the batch size.
 */
function post_process_semantic_segmentation(outputs, target_sizes = null) {
    const logits = outputs.logits;
    const batch_size = logits.dims[0];
    if (target_sizes !== null && target_sizes.length !== batch_size) {
        throw Error("Make sure that you pass in as many target sizes as the batch dimension of the logits")
    }
    const toReturn = [];
    for (let i = 0; i < batch_size; ++i) {
        const target_size = target_sizes !== null ? target_sizes[i] : null;
        let data = logits[i];
        // 1. If target_size is not null, we need to resize the masks to the target size
        if (target_size !== null) {
            // resize the masks to the target size
            data = (0,_utils_tensor_js__WEBPACK_IMPORTED_MODULE_1__.interpolate)(data, target_size, 'bilinear', false);
        }
        const [height, width] = target_size ?? data.dims.slice(-2);
        const segmentation = new _utils_tensor_js__WEBPACK_IMPORTED_MODULE_1__.Tensor(
            'int32',
            new Int32Array(height * width),
            [height, width]
        );
        // Running per-pixel maximum score. This is a copy because `data[0].data` can be a view
        // into the caller's logits. Writing through such a view would corrupt them, and a
        // repeated call on the same outputs would then label every pixel as class 0.
        const buffer = data[0].data.slice();
        const segmentation_data = segmentation.data;
        for (let j = 1; j < data.dims[0]; ++j) {
            const row = data[j].data;
            for (let k = 0; k < row.length; ++k) {
                if (row[k] > buffer[k]) {
                    buffer[k] = row[k];
                    segmentation_data[k] = j;
                }
            }
        }
        // Record which labels occur. A sparse array indexed by label is cheaper than building
        // a Set, and filtering it yields the labels in ascending order.
        const hasLabel = new Array(data.dims[0]);
        for (let j = 0; j < segmentation_data.length; ++j) {
            const index = segmentation_data[j];
            hasLabel[index] = index;
        }
        /** @type {number[]} The unique list of labels that were detected */
        const labels = hasLabel.filter(x => x !== undefined);
        toReturn.push({ segmentation, labels });
    }
    return toReturn;
}
/**
 * Filter mask-classification queries down to confident, non-background predictions.
 *
 * For each query, the arg-max class of `class_logits` is taken. Queries whose arg-max equals
 * `num_labels` (the extra "no object" class) are dropped. So are queries whose softmax score
 * for that class is not strictly greater than `object_mask_threshold`.
 * @param {Tensor} class_logits The class logits, one row per query.
 * @param {Tensor} mask_logits The mask logits, one mask per query.
 * @param {number} object_mask_threshold A number between 0 and 1; scores must exceed it.
 * @param {number} num_labels The number of labels (index of the "no object" class).
 * @returns {[Tensor[], number[], number[]]} The kept masks, their scores, and their labels.
 * @private
 */
function remove_low_and_no_objects(class_logits, mask_logits, object_mask_threshold, num_labels) {
    const { max, softmax } = _utils_maths_js__WEBPACK_IMPORTED_MODULE_2__;
    const kept_masks = [];
    const kept_scores = [];
    const kept_labels = [];
    for (let q = 0; q < class_logits.dims[0]; ++q) {
        const query_logits = class_logits[q];
        const query_mask = mask_logits[q];
        const label = max(query_logits.data)[1];
        if (label === num_labels) {
            continue; // "no object" prediction
        }
        const score = softmax(query_logits.data)[label];
        if (!(score > object_mask_threshold)) {
            continue; // not confident enough
        }
        kept_masks.push(query_mask);
        kept_scores.push(score);
        kept_labels.push(label);
    }
    return [kept_masks, kept_scores, kept_labels];
}
/**
 * Decide whether query `k` forms a valid segment.
 *
 * The segment is valid when query `k` owns at least one pixel (its label in `mask_labels`), when its
 * own mask has at least one pixel with probability >= `mask_threshold`, and when the ratio of
 * owned pixels to confident pixels exceeds `overlap_mask_area_threshold`. The ratio check
 * discards small, disconnected leftovers.
 * @param {Int32Array} mask_labels Winning query index for each pixel.
 * @param {Tensor[]} mask_probs Probabilities for each pixel in the masks.
 * @param {number} k The query index of the segment.
 * @param {number} mask_threshold The mask threshold.
 * @param {number} overlap_mask_area_threshold The overlap mask area threshold.
 * @returns {[boolean, number[]]} Whether the segment is valid, and the flat indices of the pixels owned by `k`.
 * @private
 */
function check_segment_validity(
    mask_labels,
    mask_probs,
    k,
    mask_threshold = 0.5,
    overlap_mask_area_threshold = 0.8
) {
    const probs_k = mask_probs[k].data;
    const owned_pixels = [];
    let confident_area = 0;
    for (let p = 0; p < mask_labels.length; ++p) {
        if (mask_labels[p] === k) owned_pixels.push(p);
        if (probs_k[p] >= mask_threshold) confident_area += 1;
    }
    const owned_area = owned_pixels.length;
    const is_valid = owned_area > 0
        && confident_area > 0
        && (owned_area / confident_area) > overlap_mask_area_threshold;
    return [is_valid, owned_pixels];
}
/**
 * Turn per-query mask probabilities into a segmentation map plus segment metadata.
 *
 * Steps:
 * 1. Optionally resize every mask to `target_size` (bilinear).
 * 2. Scale each mask by its query's score.
 * 3. Give every pixel to the query with the highest scaled value. Only a strictly larger value
 *    replaces the running best (which starts at 0), so ties go to the lower query index and
 *    pixels whose values are all <= 0 default to query 0.
 * 4. Keep each query that passes `check_segment_validity`, numbering kept segments 1, 2, ...
 *    in query order.
 *
 * Pixels not covered by a kept segment stay 0 in the returned map.
 * NOTE: `mask_probs` is modified in place (entries are replaced when resizing, and values are scaled by score).
 * TODO: label fusion via `label_ids_to_fuse` is not implemented; the argument is currently unused.
 *
 * @param {Tensor[]} mask_probs The mask probabilities.
 * @param {number[]} pred_scores The predicted scores.
 * @param {number[]} pred_labels The predicted labels.
 * @param {number} mask_threshold The mask threshold.
 * @param {number} overlap_mask_area_threshold The overlap mask area threshold.
 * @param {Set<number>} label_ids_to_fuse The label ids to fuse (currently unused).
 * @param {number[]} target_size The target size of the image as (height, width).
 * @returns {[Tensor, Array<{id: number, label_id: number, score: number}>]} The computed segments.
 * @private
 */
function compute_segments(
    mask_probs,
    pred_scores,
    pred_labels,
    mask_threshold,
    overlap_mask_area_threshold,
    label_ids_to_fuse = null,
    target_size = null,
) {
    const { Tensor, interpolate } = _utils_tensor_js__WEBPACK_IMPORTED_MODULE_1__;
    const [height, width] = target_size ?? mask_probs[0].dims;
    const segmentation = new Tensor('int32', new Int32Array(height * width), [height, width]);
    const segments = [];

    if (target_size !== null) {
        mask_probs.forEach((mask, q) => {
            mask_probs[q] = interpolate(mask, target_size, 'bilinear', false);
        });
    }

    const pixel_count = mask_probs[0].data.length;
    const mask_labels = new Int32Array(pixel_count);
    const best_scores = new Float32Array(pixel_count);
    mask_probs.forEach((mask, q) => {
        const weight = pred_scores[q];
        const values = mask.data;
        for (let p = 0; p < values.length; ++p) {
            values[p] *= weight;
            if (values[p] > best_scores[p]) {
                mask_labels[p] = q;
                best_scores[p] = values[p];
            }
        }
    });

    let current_segment_id = 0;
    const segmentation_data = segmentation.data;
    pred_labels.forEach((pred_class, k) => {
        const [mask_exists, mask_k] = check_segment_validity(
            mask_labels,
            mask_probs,
            k,
            mask_threshold,
            overlap_mask_area_threshold
        );
        if (!mask_exists) {
            return;
        }
        current_segment_id += 1;
        for (const pixel of mask_k) {
            segmentation_data[pixel] = current_segment_id;
        }
        segments.push({
            id: current_segment_id,
            label_id: pred_class,
            score: pred_scores[k],
        });
    });
    return [segmentation, segments];
}
/**
 * Compute a new (height, width) for an image such that:
 *
 * 1. Both dimensions are multiples of `factor`.
 * 2. The pixel count lies within [`min_pixels`, `max_pixels`] (by rescaling when the simple
 *    rounding falls outside the range).
 * 3. The aspect ratio is kept as close as possible to the original.
 *
 * @param {number} height The height of the image.
 * @param {number} width The width of the image.
 * @param {number} [factor=28] The factor to use for resizing.
 * @param {number} [min_pixels=56*56] The minimum number of pixels.
 * @param {number} [max_pixels=14*14*4*1280] The maximum number of pixels.
 * @returns {[number, number]} The new height and width of the image.
 * @throws {Error} If the height or width is smaller than the factor, or the aspect ratio exceeds 200.
 */
function smart_resize(height, width, factor = 28, min_pixels = 56 * 56, max_pixels = 14 * 14 * 4 * 1280) {
    if (height < factor || width < factor) {
        throw new Error(`height:${height} or width:${width} must be larger than factor:${factor}`);
    }
    const aspect_ratio = Math.max(height, width) / Math.min(height, width);
    if (aspect_ratio > 200) {
        throw new Error(
            `absolute aspect ratio must be smaller than 200, got ${aspect_ratio}`
        );
    }

    // Snap `value` to a multiple of `factor` using the given rounding function.
    const toMultiple = (value, roundFn) => roundFn(value / factor) * factor;
    const area = height * width;

    let h_bar = toMultiple(height, Math.round);
    let w_bar = toMultiple(width, Math.round);
    if (h_bar * w_bar > max_pixels) {
        const shrink = Math.sqrt(area / max_pixels);
        h_bar = toMultiple(height / shrink, Math.floor);
        w_bar = toMultiple(width / shrink, Math.floor);
    } else if (h_bar * w_bar < min_pixels) {
        const grow = Math.sqrt(min_pixels / area);
        h_bar = toMultiple(height * grow, Math.ceil);
        w_bar = toMultiple(width * grow, Math.ceil);
    }
    return [h_bar, w_bar];
}
/**
 * Post-process the model output to generate the final panoptic segmentation.
 *
 * Class logits are read from `outputs.class_queries_logits` (or `outputs.logits`). Mask logits are
 * read from `outputs.masks_queries_logits` (or `outputs.pred_masks`) and passed through a sigmoid.
 * The last class is treated as background. For a batch item where no query survives filtering,
 * an all `-1` map with no segments is returned.
 *
 * @param {*} outputs The model output to post process
 * @param {number} [threshold=0.5] The probability score threshold to keep predicted instance masks.
 * @param {number} [mask_threshold=0.5] Threshold to use when turning the predicted masks into binary values.
 * @param {number} [overlap_mask_area_threshold=0.8] The overlap mask area threshold to merge or discard small disconnected parts within each binary instance mask.
 * @param {Set<number>} [label_ids_to_fuse=null] The labels in this state will have all their instances be fused together.
 * @param {[number, number][]} [target_sizes=null] The target sizes to resize the masks to.
 * @returns {Array<{ segmentation: Tensor, segments_info: Array<{id: number, label_id: number, score: number}>}>}
 * @throws {Error} If `target_sizes` is given but its length differs from the batch size.
 */
function post_process_panoptic_segmentation(
    outputs,
    threshold = 0.5,
    mask_threshold = 0.5,
    overlap_mask_area_threshold = 0.8,
    label_ids_to_fuse = null,
    target_sizes = null,
) {
    let fuse_ids = label_ids_to_fuse;
    if (fuse_ids === null) {
        console.warn("`label_ids_to_fuse` unset. No instance will be fused.")
        fuse_ids = new Set();
    }

    const class_queries_logits = outputs.class_queries_logits ?? outputs.logits; // [batch_size, num_queries, num_classes+1]
    const masks_queries_logits = outputs.masks_queries_logits ?? outputs.pred_masks; // [batch_size, num_queries, height, width]
    const mask_probs = masks_queries_logits.sigmoid(); // [batch_size, num_queries, height, width]
    const [batch_size, , num_classes_with_background] = class_queries_logits.dims;
    const num_labels = num_classes_with_background - 1; // drop the last (background) class

    if (target_sizes !== null && target_sizes.length !== batch_size) {
        throw Error("Make sure that you pass in as many target sizes as the batch dimension of the logits")
    }

    const { Tensor } = _utils_tensor_js__WEBPACK_IMPORTED_MODULE_1__;
    const results = [];
    for (let b = 0; b < batch_size; ++b) {
        const target_size = target_sizes === null ? null : target_sizes[b];
        const item_class_logits = class_queries_logits[b];
        const item_mask_probs = mask_probs[b];
        const [kept_masks, kept_scores, kept_labels] = remove_low_and_no_objects(
            item_class_logits, item_mask_probs, threshold, num_labels
        );

        if (kept_labels.length === 0) {
            // No mask found: every pixel is marked -1 and no segments are reported.
            const [height, width] = target_size ?? item_mask_probs.dims.slice(-2);
            results.push({
                segmentation: new Tensor('int32', new Int32Array(height * width).fill(-1), [height, width]),
                segments_info: [],
            });
            continue;
        }

        // Get segmentation map and segment information of batch item
        const [segmentation, segments_info] = compute_segments(
            kept_masks,
            kept_scores,
            kept_labels,
            mask_threshold,
            overlap_mask_area_threshold,
            fuse_ids,
            target_size,
        );
        results.push({ segmentation, segments_info });
    }
    return results;
}
/**
 * Placeholder for instance-segmentation post-processing; it is not implemented and always throws.
 * @param {*} outputs Raw outputs of the model (unused).
 * @param {number} [threshold=0.5] The probability score threshold to keep predicted instance masks (unused).
 * @param {[number, number][]} [target_sizes=null] Requested final (height, width) of each prediction (unused).
 * @returns {Array<{ segmentation: Tensor, segments_info: Array<{id: number, label_id: number, score: number}>}>}
 * @throws {Error} Always, because the function is not yet implemented.
 */
function post_process_instance_segmentation(outputs, threshold = 0.5, target_sizes = null) {
    const message = '`post_process_instance_segmentation` is not yet implemented.';
    throw new Error(message);
}
/**
* @typedef {Object} ImageProcessorConfig A configuration object used to create an image processor.
* @property {function} [progress_callback=null] If specified, this function will be called during model construction, to provide the user with progress updates.
* @property {number[]} [image_mean] The mean values for image normalization.
* @property {number[]} [image_std] The standard deviation values for image normalization.
* @property {boolean} [do_rescale] Whether to rescale the image pixel values to the [0,1] range.
* @property {number} [rescale_factor] The factor to use for rescaling the image pixel values.
* @property {boolean} [do_normalize] Whether to normalize the image pixel values.
* @property {boolean} [do_resize] Whether to resize the image.
* @property {number} [resample] What method to use for resampling.
* @property {number|Object} [size] The size to resize the image to.
* @property {number|Object} [image_size] The size to resize the image to (same as `size`).
* @property {boolean} [do_flip_channel_order=false] Whether to flip the color channels from RGB to BGR.
* Can be overridden by the `do_flip_channel_order` parameter in the `preprocess` method.
* @property {boolean} [do_center_crop] Whether to center crop the image to the specified `crop_size`.
* Can be overridden by `do_center_crop` in the `preprocess` method.
* @property {boolean} [do_thumbnail] Whether to resize the image using thumbnail method.
* @property {boolean} [keep_aspect_ratio] If `true`, the image is resized to the largest possible size such that the aspect ratio is preserved.
* Can be overridden by `keep_aspect_ratio` in `preprocess`.
* @property {number} [ensure_multiple_of] If `do_resize` is `true`, the image is resized to a size that is a multiple of this value.
* Can be overridden by `ensure_multiple_of` in `preprocess`.
*
* @property {number[]} [mean] The mean values for image normalization (same as `image_mean`).
* @property {number[]} [std] The standard deviation values for image normalization (same as `image_std`).
*/
/**
 * Generic image processor. Converts `RawImage`s into `float32` tensors in channel-first (CHW)
 * layout by optionally cropping margins, converting colour mode, resizing, thumbnailing,
 * center-cropping, rescaling, normalizing, padding and flipping channel order.
 */
class ImageProcessor extends _utils_generic_js__WEBPACK_IMPORTED_MODULE_0__.Callable {

    /**
     * Constructs a new `ImageProcessor`.
     * @param {ImageProcessorConfig} config The configuration object.
     */
    constructor(config) {
        super();

        this.image_mean = config.image_mean ?? config.mean;
        this.image_std = config.image_std ?? config.std;

        this.resample = config.resample ?? 2; // 2 => bilinear
        this.do_rescale = config.do_rescale ?? true;
        this.rescale_factor = config.rescale_factor ?? (1 / 255);
        this.do_normalize = config.do_normalize;

        this.do_thumbnail = config.do_thumbnail;
        this.size = config.size ?? config.image_size;
        this.do_resize = config.do_resize ?? (this.size !== undefined);
        this.size_divisibility = config.size_divisibility ?? config.size_divisor;

        this.do_center_crop = config.do_center_crop;
        this.crop_size = config.crop_size;
        this.do_convert_rgb = config.do_convert_rgb ?? true;
        this.do_crop_margin = config.do_crop_margin;

        this.pad_size = config.pad_size;
        this.do_pad = config.do_pad;

        if (this.do_pad && !this.pad_size && this.size && this.size.width !== undefined && this.size.height !== undefined) {
            // Should pad, but no pad size specified:
            // infer the pad size from the (width, height) resize target.
            this.pad_size = this.size;
        }

        this.do_flip_channel_order = config.do_flip_channel_order ?? false;

        this.config = config;
    }

    /**
     * Resize the image to make a thumbnail. The image is resized so that no dimension is larger than any
     * corresponding dimension of the specified size. If the image already fits, it is returned unchanged.
     * @param {RawImage} image The image to be resized.
     * @param {{height:number, width:number}} size The size `{"height": h, "width": w}` to resize the image to.
     * @param {string | 0 | 1 | 2 | 3 | 4 | 5} [resample=2] The resampling filter to use.
     * @returns {Promise<RawImage>} The resized image.
     */
    async thumbnail(image, size, resample = 2) {
        const input_height = image.height;
        const input_width = image.width;

        const output_height = size.height;
        const output_width = size.width;

        // We always resize to the smallest of either the input or output size.
        let height = Math.min(input_height, output_height);
        let width = Math.min(input_width, output_width);

        if (height === input_height && width === input_width) {
            return image;
        }

        // Recompute the shorter side from the longer one so the aspect ratio is kept
        // (a square input keeps the clamped width/height as-is).
        if (input_height > input_width) {
            width = Math.floor(input_width * height / input_height);
        } else if (input_width > input_height) {
            height = Math.floor(input_height * width / input_width);
        }
        return await image.resize(width, height, { resample });
    }

    /**
     * Crops the image to the bounding box of its "content" pixels: those whose grayscale value,
     * after min-max normalization to [0, 1], is below `gray_threshold / 255`. Lighter pixels outside
     * that box are treated as margin. Images with a constant grayscale value are returned unchanged.
     * @param {RawImage} image The image to be cropped.
     * @param {number} gray_threshold Threshold (on a 0-255 scale) below which pixels count as content.
     * @returns {Promise<RawImage>} The cropped image.
     */
    async crop_margin(image, gray_threshold = 200) {

        const gray_image = image.clone().grayscale();

        const minValue = (0,_utils_maths_js__WEBPACK_IMPORTED_MODULE_2__.min)(gray_image.data)[0];
        const maxValue = (0,_utils_maths_js__WEBPACK_IMPORTED_MODULE_2__.max)(gray_image.data)[0];
        const diff = maxValue - minValue;

        if (diff === 0) {
            return image;
        }

        const threshold = gray_threshold / 255;

        let x_min = gray_image.width, y_min = gray_image.height, x_max = 0, y_max = 0;
        const gray_image_data = gray_image.data;
        for (let j = 0; j < gray_image.height; ++j) {
            const row = j * gray_image.width;
            for (let i = 0; i < gray_image.width; ++i) {
                if ((gray_image_data[row + i] - minValue) / diff < threshold) {
                    // Darker-than-threshold pixel (content, not margin): grow the bounding box to include it.
                    x_min = Math.min(x_min, i);
                    y_min = Math.min(y_min, j);
                    x_max = Math.max(x_max, i);
                    y_max = Math.max(y_max, j);
                }
            }
        }

        image = await image.crop([x_min, y_min, x_max, y_max]);
        return image;
    }

    /**
     * Pad the image by a certain amount.
     * @param {Float32Array} pixelData The pixel data to pad (HWC layout).
     * @param {number[]} imgDims The dimensions of the image (height, width, channels).
     * @param {{width:number; height:number}|number} padSize The dimensions of the padded image.
     * @param {Object} options The options for padding.
     * @param {'constant'|'symmetric'} [options.mode='constant'] The type of padding to add.
     * @param {boolean} [options.center=false] Whether to center the image.
     * @param {number|number[]} [options.constant_values=0] The constant value to use for padding
     * (an array is applied per channel).
     * @returns {[Float32Array, number[]]} The padded pixel data and image dimensions.
     * @throws {Error} If `center` is combined with `mode='symmetric'` and padding is needed.
     */
    pad_image(pixelData, imgDims, padSize, {
        mode = 'constant',
        center = false,
        constant_values = 0,
    } = {}) {
        const [imageHeight, imageWidth, imageChannels] = imgDims;

        let paddedImageWidth, paddedImageHeight;
        if (typeof padSize === 'number') {
            paddedImageWidth = padSize;
            paddedImageHeight = padSize;
        } else {
            paddedImageWidth = padSize.width;
            paddedImageHeight = padSize.height;
        }

        // Only add padding if there is a difference in size
        if (paddedImageWidth !== imageWidth || paddedImageHeight !== imageHeight) {
            const paddedPixelData = new Float32Array(paddedImageWidth * paddedImageHeight * imageChannels);
            if (Array.isArray(constant_values)) {
                // Fill with constant values, cycling through the array
                for (let i = 0; i < paddedPixelData.length; ++i) {
                    paddedPixelData[i] = constant_values[i % imageChannels];
                }
            } else if (constant_values !== 0) {
                paddedPixelData.fill(constant_values);
            }

            const [left, top] = center
                ? [Math.floor((paddedImageWidth - imageWidth) / 2), Math.floor((paddedImageHeight - imageHeight) / 2)]
                : [0, 0];

            // Copy the original image into the padded image
            for (let i = 0; i < imageHeight; ++i) {
                const a = (i + top) * paddedImageWidth;
                const b = i * imageWidth;
                for (let j = 0; j < imageWidth; ++j) {
                    const c = (a + j + left) * imageChannels;
                    const d = (b + j) * imageChannels;
                    for (let k = 0; k < imageChannels; ++k) {
                        paddedPixelData[c + k] = pixelData[d + k];
                    }
                }
            }

            if (mode === 'symmetric') {
                if (center) {
                    // TODO: Implement centered symmetric padding.
                    throw new Error('`center` padding is not supported when `mode` is set to `symmetric`.');
                }
                const h1 = imageHeight - 1;
                const w1 = imageWidth - 1;
                for (let i = 0; i < paddedImageHeight; ++i) {
                    const a = i * paddedImageWidth;
                    const b = (0,_utils_core_js__WEBPACK_IMPORTED_MODULE_4__.calculateReflectOffset)(i, h1) * imageWidth;

                    for (let j = 0; j < paddedImageWidth; ++j) {
                        if (i < imageHeight && j < imageWidth) continue; // Do not overwrite original image
                        const c = (a + j) * imageChannels;
                        const d = (b + (0,_utils_core_js__WEBPACK_IMPORTED_MODULE_4__.calculateReflectOffset)(j, w1)) * imageChannels;

                        // Copy channel-wise
                        for (let k = 0; k < imageChannels; ++k) {
                            paddedPixelData[c + k] = pixelData[d + k];
                        }
                    }
                }
            }

            // Update pixel data and image dimensions
            pixelData = paddedPixelData;
            imgDims = [paddedImageHeight, paddedImageWidth, imageChannels];
        }
        return [pixelData, imgDims];
    }

    /**
     * Rescale the image's pixel values in place by `this.rescale_factor`.
     * @param {Float32Array} pixelData The pixel data to rescale.
     * @returns {void}
     */
    rescale(pixelData) {
        for (let i = 0; i < pixelData.length; ++i) {
            pixelData[i] = this.rescale_factor * pixelData[i];
        }
    }

    /**
     * Find the target (width, height) dimension of the output image after
     * resizing given the input image and the desired size.
     * @param {RawImage} image The image to resize.
     * @param {any} size The size to use for resizing the image.
     * @returns {[number, number]} The target (width, height) dimension of the output image after resizing.
     * @throws {Error} If `size` has none of the supported forms.
     */
    get_resize_output_image_size(image, size) {
        // `size` comes in many forms, so we need to handle them all here:
        // 1. `size` is an integer, in which case we resize the image to be a square

        const [srcWidth, srcHeight] = image.size;

        let shortest_edge;
        let longest_edge;

        if (this.do_thumbnail) {
            // NOTE: custom logic for `Donut` models
            const { height, width } = size;
            shortest_edge = Math.min(height, width);
        }
        // Support both formats for backwards compatibility
        else if (Number.isInteger(size)) {
            shortest_edge = size;
            longest_edge = this.config.max_size ?? shortest_edge;

        } else if (size !== undefined) {
            // Extract known properties from `size`
            shortest_edge = size.shortest_edge;
            longest_edge = size.longest_edge;
        }

        // If `longest_edge` and `shortest_edge` are set, maintain aspect ratio and resize to `shortest_edge`
        // while keeping the largest dimension <= `longest_edge`
        if (shortest_edge !== undefined || longest_edge !== undefined) {
            // http://opensourcehacker.com/2011/12/01/calculate-aspect-ratio-conserving-resize-for-images-in-javascript/
            // Try resize so that shortest edge is `shortest_edge` (target)
            const shortResizeFactor = shortest_edge === undefined
                ? 1 // If `shortest_edge` is not set, don't upscale
                : Math.max(shortest_edge / srcWidth, shortest_edge / srcHeight);

            const newWidth = srcWidth * shortResizeFactor;
            const newHeight = srcHeight * shortResizeFactor;

            // The new width and height might be greater than `longest_edge`, so
            // we downscale again to ensure the largest dimension is `longest_edge`
            const longResizeFactor = longest_edge === undefined
                ? 1 // If `longest_edge` is not set, don't downscale
                : Math.min(longest_edge / newWidth, longest_edge / newHeight);

            // To avoid certain floating point precision issues, we round to 2 decimal places
            let finalWidth = Math.floor(Number((newWidth * longResizeFactor).toFixed(2)));
            let finalHeight = Math.floor(Number((newHeight * longResizeFactor).toFixed(2)));

            if (this.size_divisibility !== undefined) {
                [finalWidth, finalHeight] = enforce_size_divisibility([finalWidth, finalHeight], this.size_divisibility);
            }
            return [finalWidth, finalHeight];

        } else if (size !== undefined && size.width !== undefined && size.height !== undefined) {
            // If `width` and `height` are set, resize to those dimensions

            let newWidth = size.width;
            let newHeight = size.height;

            // Custom for DPT models
            if (this.config.keep_aspect_ratio && this.config.ensure_multiple_of) {

                // determine new height and width
                let scale_height = newHeight / srcHeight;
                let scale_width = newWidth / srcWidth;

                // scale as little as possible
                if (Math.abs(1 - scale_width) < Math.abs(1 - scale_height)) {
                    // fit width
                    scale_height = scale_width;
                } else {
                    // fit height
                    scale_width = scale_height;
                }

                newHeight = constraint_to_multiple_of(scale_height * srcHeight, this.config.ensure_multiple_of);
                newWidth = constraint_to_multiple_of(scale_width * srcWidth, this.config.ensure_multiple_of);
            }

            return [newWidth, newHeight];

        } else if (this.size_divisibility !== undefined) {
            return enforce_size_divisibility([srcWidth, srcHeight], this.size_divisibility);
        } else if (size?.min_pixels !== undefined && size?.max_pixels !== undefined) {
            // Custom resize logic for Qwen2-VL models.
            // Optional chaining: an undefined `size` must fall through to the descriptive error
            // below instead of throwing a bare TypeError here.
            const { min_pixels, max_pixels } = size;
            const factor = this.config.patch_size * this.config.merge_size;
            return smart_resize(srcHeight, srcWidth, factor, min_pixels, max_pixels);
        } else {
            throw new Error(`Could not resize image due to unsupported \`this.size\` option in config: ${JSON.stringify(size)}`);
        }
    }

    /**
     * Resizes the image to the size computed by `get_resize_output_image_size` from `this.size`.
     * @param {RawImage} image The image to resize.
     * @returns {Promise<RawImage>} The resized image.
     */
    async resize(image) {
        const [newWidth, newHeight] = this.get_resize_output_image_size(image, this.size);
        return await image.resize(newWidth, newHeight, {
            resample: this.resample,
        });
    }

    /**
     * @typedef {object} PreprocessedImage
     * @property {HeightWidth} original_size The original size of the image.
     * @property {HeightWidth} reshaped_input_size The reshaped input size of the image.
     * @property {Tensor} pixel_values The pixel values of the preprocessed image.
     */

    /**
     * Preprocesses the given image.
     *
     * @param {RawImage} image The image to preprocess.
     * @param {Object} overrides The overrides for the preprocessing options.
     * @returns {Promise<PreprocessedImage>} The preprocessed image.
     * @throws {Error} If array-valued `image_mean`/`image_std` do not match the channel count,
     * or if channel flipping is requested for a non-3-channel image.
     */
    async preprocess(image, {
        do_normalize = null,
        do_pad = null,
        do_convert_rgb = null,
        do_convert_grayscale = null,
        do_flip_channel_order = null,
    } = {}) {
        if (this.do_crop_margin) {
            // NOTE: Specific to nougat processors. This is done before resizing,
            // and can be interpreted as a pre-preprocessing step.
            image = await this.crop_margin(image);
        }

        const [srcWidth, srcHeight] = image.size; // original image size

        // Convert image to RGB if specified in config.
        if (do_convert_rgb ?? this.do_convert_rgb) {
            image = image.rgb();
        } else if (do_convert_grayscale) {
            image = image.grayscale();
        }

        // TODO:
        // For efficiency reasons, it might be best to merge the resize and center crop operations into one.

        // Resize all images
        if (this.do_resize) {
            image = await this.resize(image);
        }

        // Resize the image using thumbnail method.
        if (this.do_thumbnail) {
            image = await this.thumbnail(image, this.size, this.resample);
        }

        if (this.do_center_crop) {

            let crop_width;
            let crop_height;
            if (Number.isInteger(this.crop_size)) {
                crop_width = this.crop_size;
                crop_height = this.crop_size;
            } else {
                crop_width = this.crop_size.width;
                crop_height = this.crop_size.height;
            }

            image = await image.center_crop(crop_width, crop_height);
        }

        /** @type {HeightWidth} */
        const reshaped_input_size = [image.height, image.width];

        // NOTE: All pixel-level manipulation (i.e., modifying `pixelData`)
        // occurs with data in the hwc format (height, width, channels),
        // to emulate the behavior of the original Python code (w/ numpy).
        let pixelData = Float32Array.from(image.data);
        let imgDims = [image.height, image.width, image.channels];

        if (this.do_rescale) {
            this.rescale(pixelData);
        }

        if (do_normalize ?? this.do_normalize) {
            // Scalars are broadcast to one value per channel.
            let image_mean = this.image_mean;
            if (!Array.isArray(this.image_mean)) {
                image_mean = new Array(image.channels).fill(image_mean);
            }

            let image_std = this.image_std;
            if (!Array.isArray(this.image_std)) {
                image_std = new Array(image.channels).fill(image_std);
            }

            if (image_mean.length !== image.channels || image_std.length !== image.channels) {
                throw new Error(`When set to arrays, the length of \`image_mean\` (${image_mean.length}) and \`image_std\` (${image_std.length}) must match the number of channels in the image (${image.channels}).`);
            }

            for (let i = 0; i < pixelData.length; i += image.channels) {
                for (let j = 0; j < image.channels; ++j) {
                    pixelData[i + j] = (pixelData[i + j] - image_mean[j]) / image_std[j];
                }
            }
        }

        // do padding after rescaling/normalizing
        if (do_pad ?? this.do_pad) {
            if (this.pad_size) {
                [pixelData, imgDims] = this.pad_image(pixelData, imgDims, this.pad_size);
            } else if (this.size_divisibility) {
                const [paddedWidth, paddedHeight] = enforce_size_divisibility([imgDims[1], imgDims[0]], this.size_divisibility);
                [pixelData, imgDims] = this.pad_image(pixelData, imgDims, { width: paddedWidth, height: paddedHeight });
            }
        }

        if (do_flip_channel_order ?? this.do_flip_channel_order) {
            if (imgDims[2] !== 3) {
                throw new Error('Flipping channel order is only supported for RGB images.');
            }
            // Convert RGB to BGR
            for (let i = 0; i < pixelData.length; i += 3) {
                const temp = pixelData[i];
                pixelData[i] = pixelData[i + 2];
                pixelData[i + 2] = temp;
            }
        }

        const pixel_values = new _utils_tensor_js__WEBPACK_IMPORTED_MODULE_1__.Tensor('float32', pixelData, imgDims)
            .permute(2, 0, 1); // convert to channel dimension format (hwc -> chw)

        return {
            original_size: [srcHeight, srcWidth],
            reshaped_input_size: reshaped_input_size,
            pixel_values,
        };
    }

    /**
     * Calls the feature extraction process on an array of images,
     * preprocesses each image, and concatenates the resulting
     * features into a single Tensor.
     * @param {RawImage[]} images The image(s) to extract features from.
     * @param {...any} args Additional arguments.
     * @returns {Promise<ImageProcessorResult>} An object containing the concatenated pixel values (and other metadata) of the preprocessed images.
     */
    async _call(images, ...args) {
        if (!Array.isArray(images)) {
            images = [images];
        }
        /** @type {PreprocessedImage[]} */
        const imageData = await Promise.all(images.map(x => this.preprocess(x)));

        // Stack pixel values
        const pixel_values = (0,_utils_tensor_js__WEBPACK_IMPORTED_MODULE_1__.stack)(imageData.map(x => x.pixel_values), 0);

        return {
            pixel_values,

            // Original sizes of images
            original_sizes: imageData.map(x => x.original_size),

            // Reshaped sizes of images, before padding or cropping
            reshaped_input_sizes: imageData.map(x => x.reshaped_input_size),
        };
    }

    /**
     * Instantiate one of the processor classes of the library from a pretrained model.
     *
     * The processor class to instantiate is selected based on the `image_processor_type` (or `feature_extractor_type`; legacy)
     * property of the config object (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible)
     *
     * @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either:
     * - A string, the *model id* of a pretrained processor hosted inside a model repo on huggingface.co.
     *   Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
     *   user or organization name, like `dbmdz/bert-base-german-cased`.
     * - A path to a *directory* containing processor files, e.g., `./my_model_directory/`.
     * @param {import('../utils/hub.js').PretrainedOptions} options Additional options for loading the processor.
     *
     * @returns {Promise<ImageProcessor>} A new instance of the Processor class.
     */
    static async from_pretrained(pretrained_model_name_or_path, options) {
        const preprocessorConfig = await (0,_utils_hub_js__WEBPACK_IMPORTED_MODULE_5__.getModelJSON)(pretrained_model_name_or_path, _utils_constants_js__WEBPACK_IMPORTED_MODULE_6__.IMAGE_PROCESSOR_NAME, true, options);
        return new this(preprocessorConfig);
    }
}
/***/ }),
/***/ "./src/base/processing_utils.js":
/*!**************************************!*\
!*** ./src/base/processing_utils.js ***!
\**************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ Processor: () => (/* binding */ Processor)
/* harmony export */ });
/* harmony import */ var _utils_constants_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ../utils/constants.js */ "./src/utils/constants.js");
/* harmony import */ var _utils_generic_js__WEBPACK_IMPORTED_MODULE_1__ = __webpack_require__(/*! ../utils/generic.js */ "./src/utils/generic.js");
/* harmony import */ var _utils_hub_js__WEBPACK_IMPORTED_MODULE_2__ = __webpack_require__(/*! ../utils/hub.js */ "./src/utils/hub.js");
/**
* @file Processors are used to prepare inputs (e.g., text, image or audio) for a model.
*
* **Example:** Using a `WhisperProcessor` to prepare an audio input for a model.
* ```javascript
* import { AutoProcessor, read_audio } from '@huggingface/transformers';
*
* const processor = await AutoProcessor.from_pretrained('openai/whisper-tiny.en');
* const audio = await read_audio('https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac', 16000);
* const { input_features } = await processor(audio);
* // Tensor {
* // data: Float32Array(240000) [0.4752984642982483, 0.5597258806228638, 0.56434166431427, ...],
* // dims: [1, 80, 3000],
* // type: 'float32',
* // size: 240000,
* // }
* ```
*
* @module processors
*/
/**
* @typedef {Object} ProcessorProperties Additional processor-specific properties.
* @typedef {import('../utils/hub.js').PretrainedOptions & ProcessorProperties} PretrainedProcessorOptions
*/
/**
 * A Processor bundles the components (image processor, tokenizer and/or feature extractor)
 * used to prepare raw inputs for a model, and dispatches calls to the first one available.
 */
class Processor extends _utils_generic_js__WEBPACK_IMPORTED_MODULE_1__.Callable {
    /**
     * Names of static properties a subclass may define, each holding a component class with a
     * `from_pretrained` method. The key under which a loaded component is stored in
     * `this.components` is the property name with its `_class` suffix removed.
     */
    static classes = [
        'image_processor_class',
        'tokenizer_class',
        'feature_extractor_class',
    ];

    /** When true, `from_pretrained` also fetches the processor config file. */
    static uses_processor_config = false;

    /**
     * Creates a new Processor with the given components
     * @param {Object} config The processor configuration (an empty object when none was loaded).
     * @param {Record<string, Object>} components The loaded components, keyed by component name.
     */
    constructor(config, components) {
        super();
        this.config = config;
        this.components = components;
    }

    /**
     * @returns {import('./image_processors_utils.js').ImageProcessor|undefined} The image processor of the processor, if it exists.
     */
    get image_processor() {
        const { image_processor } = this.components;
        return image_processor;
    }

    /**
     * @returns {import('../tokenizers.js').PreTrainedTokenizer|undefined} The tokenizer of the processor, if it exists.
     */
    get tokenizer() {
        const { tokenizer } = this.components;
        return tokenizer;
    }

    /**
     * @returns {import('./feature_extraction_utils.js').FeatureExtractor|undefined} The feature extractor of the processor, if it exists.
     */
    get feature_extractor() {
        const { feature_extractor } = this.components;
        return feature_extractor;
    }

    /**
     * Forwards to the tokenizer's `apply_chat_template`, defaulting `tokenize` to `false`
     * unless the caller's `options` say otherwise.
     * @param {any} messages The chat messages to format.
     * @param {Object} [options={}] Extra options for the tokenizer.
     * @returns {any} Whatever the tokenizer's `apply_chat_template` returns.
     * @throws {Error} If this processor has no tokenizer.
     */
    apply_chat_template(messages, options = {}) {
        const tok = this.tokenizer;
        if (!tok) {
            throw new Error('Unable to apply chat template without a tokenizer.');
        }
        const templateOptions = { tokenize: false, ...options };
        return tok.apply_chat_template(messages, templateOptions);
    }

    /**
     * Forwards all arguments to the tokenizer's `batch_decode`.
     * @param {...any} args Arguments passed through unchanged.
     * @returns {any} Whatever the tokenizer's `batch_decode` returns.
     * @throws {Error} If this processor has no tokenizer.
     */
    batch_decode(...args) {
        const tok = this.tokenizer;
        if (!tok) {
            throw new Error('Unable to decode without a tokenizer.');
        }
        return tok.batch_decode(...args);
    }

    /**
     * Invokes the first available component, checked in the order
     * image processor, feature extractor, tokenizer.
     * @param {any} input The input to extract features from.
     * @param {...any} args Additional arguments.
     * @returns {Promise<any>} A Promise that resolves with the extracted features.
     * @throws {Error} If the processor has none of the three components.
     */
    async _call(input, ...args) {
        const candidates = [this.image_processor, this.feature_extractor, this.tokenizer];
        const handler = candidates.find((component) => component);
        if (!handler) {
            throw new Error('No image processor, feature extractor, or tokenizer found.');
        }
        return handler(input, ...args);
    }

    /**
     * Instantiate this processor class from a pretrained model.
     *
     * Every name listed in `static classes` that is defined on this class is loaded concurrently
     * via its `from_pretrained`. If `uses_processor_config` is true, the processor config file is
     * fetched at the same time; otherwise an empty object is used as the config.
     *
     * @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either:
     * - A string, the *model id* of a pretrained processor hosted inside a model repo on huggingface.co.
     *   Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
     *   user or organization name, like `dbmdz/bert-base-german-cased`.
     * - A path to a *directory* containing processor files, e.g., `./my_model_directory/`.
     * @param {PretrainedProcessorOptions} options Additional options for loading the processor.
     *
     * @returns {Promise<Processor>} A new instance of the Processor class.
     */
    static async from_pretrained(pretrained_model_name_or_path, options) {
        const configPromise = this.uses_processor_config
            ? (0,_utils_hub_js__WEBPACK_IMPORTED_MODULE_2__.getModelJSON)(pretrained_model_name_or_path, _utils_constants_js__WEBPACK_IMPORTED_MODULE_0__.PROCESSOR_NAME, true, options)
            : {};

        const presentClassKeys = this.classes.filter((key) => key in this);
        const componentsPromise = Promise.all(
            presentClassKeys.map(async (key) => {
                const loaded = await this[key].from_pretrained(pretrained_model_name_or_path, options);
                return [key.replace(/_class$/, ''), loaded];
            }),
        ).then(Object.fromEntries);

        const [config, components] = await Promise.all([configPromise, componentsPromise]);
        return new this(config, components);
    }
}
/***/ }),
/***/ "./src/configs.js":
/*!************************!*\
!*** ./src/configs.js ***!
\************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ AutoConfig: () => (/* binding */ AutoConfig),
/* harmony export */ PretrainedConfig: () => (/* binding */ PretrainedConfig),
/* harmony export */ getKeyValueShapes: () => (/* binding */ getKeyValueShapes)
/* harmony export */ });
/* harmony import */ var _utils_core_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ./utils/core.js */ "./src/utils/core.js");
/* harmony import */ var _utils_hub_js__WEBPACK_IMPORTED_MODULE_1__ = __webpack_require__(/*! ./utils/hub.js */ "./src/utils/hub.js");
/**
* @file Helper module for using model configs. For more information, see the corresponding
* [Python documentation](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoConfig).
*
* **Example:** Load an `AutoConfig`.
*
* ```javascript
* import { AutoConfig } from '@huggingface/transformers';
* const config = await AutoConfig.from_pretrained('bert-base-uncased');
* console.log(config);
* // PretrainedConfig {
* // "model_type": "bert",
* // "is_encoder_decoder": false,
* // "architectures": [
* // "BertForMaskedLM"
* // ],
* // "vocab_size": 30522,
* // "num_attention_heads": 12,
* // "num_hidden_layers": 12,
* // "hidden_size": 768,
* // "max_position_embeddings": 512,
* // ...
* // }
* ```
*
* @module configs
*/
/**
* @typedef {import('./utils/hub.js').PretrainedOptions} PretrainedOptions
*/
/**
* @typedef {import('./utils/core.js').ProgressCallback} ProgressCallback
*/
/**
* @typedef {import('./utils/core.js').ProgressInfo} ProgressInfo
*/
/**
 * Fetches and parses the `config.json` file of a model.
 * @param {string} pretrained_model_name_or_path The model id, or the path to the config directory.
 * @param {PretrainedOptions} options Additional options, forwarded unchanged to `getModelJSON`.
 * @returns {Promise<Object>} A promise that resolves with the parsed config JSON.
 */
async function loadConfig(pretrained_model_name_or_path, options) {
    const configFileName = 'config.json';
    // The third argument is `true`; presumably this makes a missing file an error — TODO confirm in hub.js.
    const parsedConfig = await (0,_utils_hub_js__WEBPACK_IMPORTED_MODULE_1__.getModelJSON)(pretrained_model_name_or_path, configFileName, true, options);
    return parsedConfig;
}
/**
 * Builds a model-agnostic view of a model config by renaming architecture-specific keys
 * (e.g. `n_head`, `num_key_value_heads`, `decoder_layers`) to common names such as
 * `num_heads`, `num_layers`, `hidden_size`, `num_decoder_layers`, `num_encoder_heads`, ...
 *
 * Multimodal wrappers (llava, paligemma, florence2, llava_onevision, moondream1, musicgen,
 * multi_modality) are normalized from their language sub-config, and the outer
 * `model_type`/`multi_query`/`is_encoder_decoder` values (when defined) are layered on top.
 * `vision-encoder-decoder` configs are normalized from their `decoder` sub-config and
 * returned early with only the decoder-derived sizes.
 *
 * Keys whose source field is missing in `config` end up as `undefined`.
 *
 * @param {PretrainedConfig} config The (raw) model configuration.
 * @returns {Object} The normalized configuration.
 */
function getNormalizedConfig(config) {
    // Maps normalized key -> key in `config` to read the value from.
    const mapping = {};

    let init_normalized_config = {};
    switch (config.model_type) {
        // Sub-configs
        case 'llava':
        case 'paligemma':
        case 'florence2':
        case 'llava_onevision':
            init_normalized_config = getNormalizedConfig(config.text_config);
            break;
        case 'moondream1':
            init_normalized_config = getNormalizedConfig(config.phi_config);
            break;
        case 'musicgen':
            init_normalized_config = getNormalizedConfig(config.decoder);
            break;
        case 'multi_modality':
            init_normalized_config = getNormalizedConfig(config.language_config);
            break;

        // Decoder-only models
        case 'gpt2':
        case 'gptj':
        case 'jais':
        case 'codegen':
        case 'gpt_bigcode':
            mapping['num_heads'] = 'n_head';
            mapping['num_layers'] = 'n_layer';
            mapping['hidden_size'] = 'n_embd';
            break;
        case 'gpt_neox':
        case 'stablelm':
        case 'opt':
        case 'phi':
        case 'phi3':
        case 'falcon':
            mapping['num_heads'] = 'num_attention_heads';
            mapping['num_layers'] = 'num_hidden_layers';
            mapping['hidden_size'] = 'hidden_size';
            break;
        case 'llama':
        case 'olmo':
        case 'mobilellm':
        case 'granite':
        case 'cohere':
        case 'mistral':
        case 'starcoder2':
        case 'qwen2':
        case 'qwen2_vl':
            // `num_heads` is the KV-head count; the query-head count is kept separately.
            mapping['num_heads'] = 'num_key_value_heads';
            mapping['num_layers'] = 'num_hidden_layers';
            mapping['hidden_size'] = 'hidden_size';
            mapping['num_attention_heads'] = 'num_attention_heads';
            break;
        case 'gemma':
        case 'gemma2':
            mapping['num_heads'] = 'num_key_value_heads';
            mapping['num_layers'] = 'num_hidden_layers';
            mapping['dim_kv'] = 'head_dim';
            break;
        case 'openelm':
            mapping['num_heads'] = 'num_kv_heads';
            mapping['num_layers'] = 'num_transformer_layers';
            mapping['dim_kv'] = 'head_dim';
            break;
        case 'gpt_neo':
        case 'donut-swin':
            mapping['num_heads'] = 'num_heads';
            mapping['num_layers'] = 'num_layers';
            mapping['hidden_size'] = 'hidden_size';
            break;
        case 'bloom':
            mapping['num_heads'] = 'n_head';
            mapping['num_layers'] = 'n_layer';
            mapping['hidden_size'] = 'hidden_size';
            break;
        case 'mpt':
            mapping['num_heads'] = 'n_heads';
            mapping['num_layers'] = 'n_layers';
            mapping['hidden_size'] = 'd_model';
            break;

        // Encoder-decoder models
        case 't5':
        case 'mt5':
        case 'longt5':
            mapping['num_decoder_layers'] = 'num_decoder_layers';
            mapping['num_decoder_heads'] = 'num_heads';
            mapping['decoder_dim_kv'] = 'd_kv';
            mapping['num_encoder_layers'] = 'num_layers';
            mapping['num_encoder_heads'] = 'num_heads';
            mapping['encoder_dim_kv'] = 'd_kv';
            break;
        case 'bart':
        case 'mbart':
        case 'marian':
        case 'whisper':
        case 'm2m_100':
        case 'blenderbot':
        case 'blenderbot-small':
        case 'florence2_language':
            mapping['num_decoder_layers'] = 'decoder_layers';
            mapping['num_decoder_heads'] = 'decoder_attention_heads';
            mapping['decoder_hidden_size'] = 'd_model';
            mapping['num_encoder_layers'] = 'encoder_layers';
            mapping['num_encoder_heads'] = 'encoder_attention_heads';
            mapping['encoder_hidden_size'] = 'd_model';
            break;
        case 'speecht5':
            mapping['num_decoder_layers'] = 'decoder_layers';
            mapping['num_decoder_heads'] = 'decoder_attention_heads';
            mapping['decoder_hidden_size'] = 'hidden_size';
            mapping['num_encoder_layers'] = 'encoder_layers';
            mapping['num_encoder_heads'] = 'encoder_attention_heads';
            mapping['encoder_hidden_size'] = 'hidden_size';
            break;
        case 'trocr':
            // Encoder-side sizes are taken from the decoder fields.
            mapping['num_encoder_layers'] = mapping['num_decoder_layers'] = 'decoder_layers';
            mapping['num_encoder_heads'] = mapping['num_decoder_heads'] = 'decoder_attention_heads';
            mapping['encoder_hidden_size'] = mapping['decoder_hidden_size'] = 'd_model';
            break;
        case 'musicgen_decoder':
            mapping['num_encoder_layers'] = mapping['num_decoder_layers'] = 'num_hidden_layers';
            mapping['num_encoder_heads'] = mapping['num_decoder_heads'] = 'num_attention_heads';
            mapping['encoder_hidden_size'] = mapping['decoder_hidden_size'] = 'hidden_size';
            break;

        case 'vision-encoder-decoder': {
            // Braces scope these declarations to this clause (lexical declarations in a bare
            // `case` are visible, in their TDZ, to every other clause of the switch).
            const decoderConfig = getNormalizedConfig(config.decoder);

            const add_encoder_pkv = 'num_decoder_layers' in decoderConfig;
            const result = (0,_utils_core_js__WEBPACK_IMPORTED_MODULE_0__.pick)(config, ['model_type', 'is_encoder_decoder']);
            if (add_encoder_pkv) {
                // Decoder is part of an encoder-decoder model
                result.num_decoder_layers = decoderConfig.num_decoder_layers;
                result.num_decoder_heads = decoderConfig.num_decoder_heads;
                result.decoder_hidden_size = decoderConfig.decoder_hidden_size;

                result.num_encoder_layers = decoderConfig.num_encoder_layers;
                result.num_encoder_heads = decoderConfig.num_encoder_heads;
                result.encoder_hidden_size = decoderConfig.encoder_hidden_size;
            } else {
                // Decoder is a decoder-only model
                result.num_layers = decoderConfig.num_layers;
                result.num_heads = decoderConfig.num_heads;
                result.hidden_size = decoderConfig.hidden_size;
            }
            return result;
        }
    }

    // NOTE: If `num_attention_heads` is not set, it is assumed to be equal to `num_heads`
    const normalized_config = {
        ...init_normalized_config,
        ...(0,_utils_core_js__WEBPACK_IMPORTED_MODULE_0__.pick)(config, ['model_type', 'multi_query', 'is_encoder_decoder']),
    };
    for (const key in mapping) {
        normalized_config[key] = config[mapping[key]];
    }
    return normalized_config;
}
/**
 * Computes the shapes of the empty past-key-value inputs for a model, keyed by input name.
 * Every shape has a 0-length past-sequence dimension. Entries that share a layout reuse the
 * same array object.
 *
 * Encoder-decoder configs (with both `num_encoder_heads` and `num_decoder_heads`) produce
 * `${prefix}.${i}.{encoder,decoder}.{key,value}` entries. Other configs produce
 * `${prefix}.${i}.{key,value}` entries, or `${prefix}.${i}.key_value` when `multi_query` is set.
 * Falcon, Bloom and OpenELM have custom layouts.
 *
 * @param {PretrainedConfig} config A config exposing `normalized_config`.
 * @param {Object} [options]
 * @param {string} [options.prefix='past_key_values'] Prefix of every generated input name.
 * @param {number} [options.batch_size=1] Batch size used in the leading dimension.
 * @returns {Record<string, number[]>}
 */
function getKeyValueShapes(config, {
    prefix = 'past_key_values',
    batch_size = 1,
} = {}) {
    const cfg = config.normalized_config;

    /** @type {Record<string, number[]>} */
    const shapes = {};

    const hasEncoderAndDecoderHeads = cfg.is_encoder_decoder
        && 'num_encoder_heads' in cfg
        && 'num_decoder_heads' in cfg;

    if (hasEncoderAndDecoderHeads) {
        const encoderHeadDim = cfg.encoder_dim_kv ?? (cfg.encoder_hidden_size / cfg.num_encoder_heads);
        const decoderHeadDim = cfg.decoder_dim_kv ?? (cfg.decoder_hidden_size / cfg.num_decoder_heads);

        const encoderShape = [batch_size, cfg.num_encoder_heads, 0, encoderHeadDim];
        const decoderShape = [batch_size, cfg.num_decoder_heads, 0, decoderHeadDim];

        for (let layer = 0; layer < cfg.num_decoder_layers; ++layer) {
            const base = `${prefix}.${layer}`;
            shapes[`${base}.encoder.key`] = encoderShape;
            shapes[`${base}.encoder.value`] = encoderShape;
            shapes[`${base}.decoder.key`] = decoderShape;
            shapes[`${base}.decoder.value`] = decoderShape;
        }
        return shapes;
    }

    // Decoders
    const { num_heads, num_layers, model_type } = cfg;
    // When `num_attention_heads` is absent, the head dimension is derived from `num_heads`.
    const headDim = cfg.dim_kv ?? (cfg.hidden_size / (cfg.num_attention_heads ?? num_heads));

    const forEachLayer = (visit) => {
        for (let layer = 0; layer < num_layers; ++layer) {
            visit(`${prefix}.${layer}`, layer);
        }
    };

    if (model_type === 'falcon') {
        // NOTE: Custom implementation for Falcon (batch and heads are merged).
        const shape = [batch_size * num_heads, 0, headDim];
        forEachLayer((base) => {
            shapes[`${base}.key`] = shape;
            shapes[`${base}.value`] = shape;
        });
    } else if (cfg.multi_query) {
        // e.g., for `gpt_bigcode`: key and value share one fused input.
        const shape = [batch_size * num_heads, 0, 2 * headDim];
        forEachLayer((base) => {
            shapes[`${base}.key_value`] = shape;
        });
    } else if (model_type === 'bloom') {
        // NOTE: Custom implementation for Bloom (keys are stored transposed).
        const keyShape = [batch_size * num_heads, headDim, 0];
        const valueShape = [batch_size * num_heads, 0, headDim];
        forEachLayer((base) => {
            shapes[`${base}.key`] = keyShape;
            shapes[`${base}.value`] = valueShape;
        });
    } else if (model_type === 'openelm') {
        // `num_heads` is indexed per layer here.
        forEachLayer((base, layer) => {
            const shape = [batch_size, num_heads[layer], 0, headDim];
            shapes[`${base}.key`] = shape;
            shapes[`${base}.value`] = shape;
        });
    } else {
        // Decoder-only
        const shape = [batch_size, num_heads, 0, headDim];
        forEachLayer((base) => {
            shapes[`${base}.key`] = shape;
            shapes[`${base}.value`] = shape;
        });
    }
    return shapes;
}
/**
 * Base class for all configuration classes. For more information, see the corresponding
 * [Python documentation](https://huggingface.co/docs/transformers/main/en/main_classes/configuration#transformers.PretrainedConfig).
 */
class PretrainedConfig {
    // NOTE: The `Pretrained` (lower-case "t") casing mirrors the original Python class name.

    /** @type {string|null} */
    model_type = null;

    /** @type {boolean} */
    is_encoder_decoder = false;

    /** @type {number} */
    max_position_embeddings;

    /** @type {TransformersJSConfig} */
    'transformers.js_config';

    /**
     * Create a new PretrainedConfig instance.
     *
     * Every own enumerable property of `configJSON` is copied onto the instance (overriding the
     * field defaults declared above), after which `normalized_config` is derived from the result.
     * @param {Object} configJSON The JSON of the config.
     */
    constructor(configJSON) {
        Object.assign(this, configJSON);
        this.normalized_config = getNormalizedConfig(this);
    }

    /**
     * Loads a pre-trained config from the given `pretrained_model_name_or_path`.
     *
     * If `config` is supplied, it is used directly (wrapped in a `PretrainedConfig` first when it
     * is a plain object) and nothing is loaded; otherwise `config.json` is fetched.
     *
     * @param {string} pretrained_model_name_or_path The path to the pre-trained config.
     * @param {PretrainedOptions} options Additional options for loading the config.
     * @throws {Error} Throws an error if the config.json is not found in the `pretrained_model_name_or_path`.
     *
     * @returns {Promise<PretrainedConfig>} A new instance of the `PretrainedConfig` class.
     */
    static async from_pretrained(pretrained_model_name_or_path, {
        progress_callback = null,
        config = null,
        cache_dir = null,
        local_files_only = false,
        revision = 'main',
    } = {}) {
        if (config && !(config instanceof PretrainedConfig)) {
            config = new PretrainedConfig(config);
        }

        const data = config ?? await loadConfig(pretrained_model_name_or_path, {
            progress_callback,
            config,
            cache_dir,
            local_files_only,
            revision,
        });
        return new this(data);
    }
}
/**
 * Helper class used to instantiate pretrained configs with the `from_pretrained` function.
 * It simply forwards all arguments to `PretrainedConfig.from_pretrained`.
 *
 * @example
 * const config = await AutoConfig.from_pretrained('Xenova/bert-base-uncased');
 */
class AutoConfig {
    /** @type {typeof PretrainedConfig.from_pretrained} */
    static async from_pretrained(...args) {
        // Invoked on PretrainedConfig itself, so the result is a PretrainedConfig (not a subclass).
        const loaded = await PretrainedConfig.from_pretrained(...args);
        return loaded;
    }
}
/**
* Transformers.js-specific configuration, possibly present in config.json under the key `transformers.js_config`.
* @typedef {Object} TransformersJSConfig
* @property {import('./utils/tensor.js').DataType|Record<import('./utils/dtypes.js').DataType, import('./utils/tensor.js').DataType>} [kv_cache_dtype] The data type of the key-value cache.
* @property {Record<string, number>} [free_dimension_overrides] Override the free dimensions of the model.
* See https://onnxruntime.ai/docs/tutorials/web/env-flags-and-session-options.html#freedimensionoverrides
* for more information.
* @property {import('./utils/devices.js').DeviceType} [device] The default device to use for the model.
* @property {import('./utils/dtypes.js').DataType} [dtype] The default data type to use for the model.
* @property {boolean|Record<string, boolean>} [use_external_data_format=false] Whether to load the model using the external data format (used for models >= 2GB in size).
*/
/***/ }),
/***/ "./src/env.js":
/*!********************!*\
!*** ./src/env.js ***!
\********************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ apis: () => (/* binding */ apis),
/* harmony export */ env: () => (/* binding */ env)
/* harmony export */ });
/* harmony import */ var fs__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! fs */ "?569f");
/* harmony import */ var path__WEBPACK_IMPORTED_MODULE_1__ = __webpack_require__(/*! path */ "?3f59");
/* harmony import */ var url__WEBPACK_IMPORTED_MODULE_2__ = __webpack_require__(/*! url */ "?154a");
/**
* @file Module used to configure Transformers.js.
*
* **Example:** Disable remote models.
* ```javascript
* import { env } from '@huggingface/transformers';
* env.allowRemoteModels = false;
* ```
*
* **Example:** Set local model path.
* ```javascript
* import { env } from '@huggingface/transformers';
* env.localModelPath = '/path/to/local/models/';
* ```
*
* **Example:** Set cache directory.
* ```javascript
* import { env } from '@huggingface/transformers';
* env.cacheDir = '/path/to/cache/directory/';
* ```
*
* @module env
*/
// Library version string (exposed to users as `env.version`).
const VERSION = '3.1.0';
// Check if various APIs are available (depends on environment)
// True whenever a global `self` exists — this covers windows and worker scopes alike.
const IS_BROWSER_ENV = typeof self !== 'undefined';
// Narrower than IS_BROWSER_ENV: only matches dedicated workers, detected by the global's constructor name.
const IS_WEBWORKER_ENV = IS_BROWSER_ENV && self.constructor.name === 'DedicatedWorkerGlobalScope';
// The Cache API is looked up on the global `self`.
const IS_WEB_CACHE_AVAILABLE = IS_BROWSER_ENV && 'caches' in self;
// Presence checks only (`navigator.gpu` / `navigator.ml` exist); no adapter or context is requested here.
const IS_WEBGPU_AVAILABLE = typeof navigator !== 'undefined' && 'gpu' in navigator;
const IS_WEBNN_AVAILABLE = typeof navigator !== 'undefined' && 'ml' in navigator;
const IS_PROCESS_AVAILABLE = typeof process !== 'undefined';
// A `process` global alone is not enough; the runtime must also report itself as 'node'.
const IS_NODE_ENV = IS_PROCESS_AVAILABLE && process?.release?.name === 'node';
// `fs`/`path` resolve to webpack "ignored" module stubs in this bundle (see the imports of this module);
// such stubs presumably export an empty object, making these checks false — NOTE(review): confirm.
const IS_FS_AVAILABLE = !isEmpty(fs__WEBPACK_IMPORTED_MODULE_0__);
const IS_PATH_AVAILABLE = !isEmpty(path__WEBPACK_IMPORTED_MODULE_1__);
/**
 * A read-only (frozen) snapshot of the API-availability checks computed when this module is
 * evaluated. Each key has the same name as the corresponding module-level constant.
 */
const apis = Object.freeze({
    // Whether we are running in a browser environment
    IS_BROWSER_ENV: IS_BROWSER_ENV,
    // Whether we are running in a web worker environment
    IS_WEBWORKER_ENV: IS_WEBWORKER_ENV,
    // Whether the Cache API is available
    IS_WEB_CACHE_AVAILABLE: IS_WEB_CACHE_AVAILABLE,
    // Whether the WebGPU API is available
    IS_WEBGPU_AVAILABLE: IS_WEBGPU_AVAILABLE,
    // Whether the WebNN API is available
    IS_WEBNN_AVAILABLE: IS_WEBNN_AVAILABLE,
    // Whether the Node.js process API is available
    IS_PROCESS_AVAILABLE: IS_PROCESS_AVAILABLE,
    // Whether we are running in a Node.js environment
    IS_NODE_ENV: IS_NODE_ENV,
    // Whether the filesystem API is available
    IS_FS_AVAILABLE: IS_FS_AVAILABLE,
    // Whether the path API is available
    IS_PATH_AVAILABLE: IS_PATH_AVAILABLE,
});
const RUNNING_LOCALLY = IS_FS_AVAILABLE && IS_PATH_AVAILABLE;

/**
 * Base directory that the default cache and local-model paths are built from.
 * When running locally this is the parent of the directory containing this file (ESM: derived
 * from `import.meta.url`; CommonJS: from `__dirname`). Otherwise it stays `'./'`.
 */
const dirname__ = (() => {
    if (!RUNNING_LOCALLY) {
        return './';
    }
    // NOTE: We wrap `import.meta` in a call to `Object` to prevent Webpack from trying to bundle it in CommonJS.
    // Although we get the warning: "Accessing import.meta directly is unsupported (only property access or destructuring is supported)",
    // it is safe to ignore since the bundled value (`{}`) isn't used for CommonJS environments (we use __dirname instead).
    const _import_meta_url = Object(import.meta).url;
    if (_import_meta_url) {
        // ESM
        const thisFile = url__WEBPACK_IMPORTED_MODULE_2__.fileURLToPath(_import_meta_url);
        return path__WEBPACK_IMPORTED_MODULE_1__.dirname(path__WEBPACK_IMPORTED_MODULE_1__.dirname(thisFile));
    }
    if (typeof __dirname !== 'undefined') {
        // CommonJS
        return path__WEBPACK_IMPORTED_MODULE_1__.dirname(__dirname);
    }
    return './';
})();

// Only meaningful for environments with access to the file system; `null` everywhere else.
const DEFAULT_CACHE_DIR = RUNNING_LOCALLY
    ? path__WEBPACK_IMPORTED_MODULE_1__.join(dirname__, '/.cache/')
    : null;

// Local model path: joined onto the base directory when running locally, used verbatim otherwise.
const DEFAULT_LOCAL_MODEL_PATH = '/models/';
const localModelPath = RUNNING_LOCALLY
    ? path__WEBPACK_IMPORTED_MODULE_1__.join(dirname__, DEFAULT_LOCAL_MODEL_PATH)
    : DEFAULT_LOCAL_MODEL_PATH;
/**
 * Global settings object visible to users for controlling execution. It provides a simple way to configure Transformers.js.
 * @typedef {Object} TransformersEnvironment
 * @property {string} version The version of Transformers.js.
 * @property {{onnx: Partial<import('onnxruntime-common').Env>}} backends Exposes environment variables of the different backends,
 * allowing users to set these variables if they want to.
 * @property {boolean} allowRemoteModels Whether to allow loading of remote files, defaults to `true`.
 * If set to `false`, it will have the same effect as setting `local_files_only=true` when loading pipelines, models, tokenizers, processors, etc.
 * @property {string} remoteHost Host URL to load models from. Defaults to the Hugging Face Hub.
 * @property {string} remotePathTemplate Path template to fill in and append to `remoteHost` when loading models.
 * @property {boolean} allowLocalModels Whether to allow loading of local files. Defaults to `false` when a global `self` exists
 * (browser or worker), and `true` otherwise.
 * If set to `false`, it will skip the local file check and try to load the model from the remote host.
 * @property {string} localModelPath Path to load local models from. Defaults to `/models/`, joined onto the module's base
 * directory (`dirname__`) when the file system and path APIs are available.
 * @property {boolean} useFS Whether to use the file system to load files. By default, it is `true` if available.
 * @property {boolean} useBrowserCache Whether to use Cache API to cache models. By default, it is `true` if available.
 * @property {boolean} useFSCache Whether to use the file system to cache files. By default, it is `true` if available.
 * @property {string} cacheDir The directory to use for caching files with the file system. By default, it is `.cache/` under
 * the module's base directory (`dirname__`) when the file system and path APIs are available, and `null` otherwise.
 * @property {boolean} useCustomCache Whether to use a custom cache system (defined by `customCache`), defaults to `false`.
 * @property {Object} customCache The custom cache to use. Defaults to `null`. Note: this must be an object which
 * implements the `match` and `put` functions of the Web Cache API. For more information, see https://developer.mozilla.org/en-US/docs/Web/API/Cache
 */

/** @type {TransformersEnvironment} */
const env = {
    version: VERSION,

    // ---- Backends settings ----
    // NOTE: Left empty here; the backends populate these themselves later.
    backends: {
        onnx: {}, // onnxruntime-web / onnxruntime-node
    },

    // ---- Model settings ----
    allowRemoteModels: true,
    remoteHost: 'https://huggingface.co/',
    remotePathTemplate: '{model}/resolve/{revision}/',
    allowLocalModels: !IS_BROWSER_ENV,
    localModelPath,
    useFS: IS_FS_AVAILABLE,

    // ---- Cache settings ----
    useBrowserCache: IS_WEB_CACHE_AVAILABLE,
    useFSCache: IS_FS_AVAILABLE,
    cacheDir: DEFAULT_CACHE_DIR,
    useCustomCache: false,
    customCache: null,
};
/**
 * Check whether an object has no own enumerable string-keyed properties.
 * Used in this module to decide whether the bundled `fs` / `path` modules are real
 * implementations or empty stand-ins.
 *
 * `null` and `undefined` are treated as empty instead of throwing a `TypeError`, so a missing
 * module reads as "unavailable" rather than crashing module evaluation.
 *
 * @param {Object|null|undefined} obj The object to inspect.
 * @returns {boolean} `true` if `obj` is nullish or has no own enumerable keys, `false` otherwise.
 * @private
 */
function isEmpty(obj) {
    return obj == null || Object.keys(obj).length === 0;
}
/***/ }),
/***/ "./src/generation/configuration_utils.js":
/*!***********************************************!*\
!*** ./src/generation/configuration_utils.js ***!
\***********************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ GenerationConfig: () => (/* binding */ GenerationConfig)
/* harmony export */ });
/* harmony import */ var _utils_core_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ../utils/core.js */ "./src/utils/core.js");
/**
* @module generation/configuration_utils
*/
/**
 * Class that holds a configuration for a generation task.
 * Every public field below is a generation parameter with its default value; the constructor
 * copies matching keys from a model/generation config over these defaults.
 */
class GenerationConfig {
    // Parameters that control the length of the output
    /**
     * The maximum length the generated tokens can have.
     * Corresponds to the length of the input prompt + `max_new_tokens`.
     * Its effect is overridden by `max_new_tokens`, if also set.
     * @type {number}
     * @default 20
     */
    max_length = 20;

    /**
     * The maximum number of tokens to generate, ignoring the number of tokens in the prompt.
     * @type {number}
     * @default null
     */
    max_new_tokens = null;

    /**
     * The minimum length of the sequence to be generated.
     * Corresponds to the length of the input prompt + `min_new_tokens`.
     * Its effect is overridden by `min_new_tokens`, if also set.
     * @type {number}
     * @default 0
     */
    min_length = 0;

    /**
     * The minimum number of tokens to generate, ignoring the number of tokens in the prompt.
     * @type {number}
     * @default null
     */
    min_new_tokens = null;

    /**
     * Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
     * - `true`, where the generation stops as soon as there are `num_beams` complete candidates;
     * - `false`, where a heuristic is applied and the generation stops when it is very unlikely to find better candidates;
     * - `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm).
     * @type {boolean|"never"}
     * @default false
     */
    early_stopping = false;

    /**
     * The maximum amount of time you allow the computation to run for in seconds.
     * Generation will still finish the current pass after allocated time has been passed.
     * @type {number}
     * @default null
     */
    max_time = null;

    // Parameters that control the generation strategy used
    /**
     * Whether or not to use sampling; use greedy decoding otherwise.
     * @type {boolean}
     * @default false
     */
    do_sample = false;

    /**
     * Number of beams for beam search. 1 means no beam search.
     * @type {number}
     * @default 1
     */
    num_beams = 1;

    /**
     * Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
     * See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
     * @type {number}
     * @default 1
     */
    num_beam_groups = 1;

    /**
     * The values balance the model confidence and the degeneration penalty in contrastive search decoding.
     * @type {number}
     * @default null
     */
    penalty_alpha = null;

    /**
     * Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.
     * @type {boolean}
     * @default true
     */
    use_cache = true;

    // Parameters for manipulation of the model output logits
    /**
     * The value used to modulate the next token probabilities.
     * @type {number}
     * @default 1.0
     */
    temperature = 1.0;

    /**
     * The number of highest probability vocabulary tokens to keep for top-k-filtering.
     * @type {number}
     * @default 50
     */
    top_k = 50;

    /**
     * If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.
     * @type {number}
     * @default 1.0
     */
    top_p = 1.0;

    /**
     * Local typicality measures how similar the conditional probability of predicting a target token next is to the expected conditional probability of predicting a random token next, given the partial text already generated.
     * If set to float < 1, the smallest set of the most locally typical tokens with probabilities that add up to `typical_p` or higher are kept for generation.
     * See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.
     * @type {number}
     * @default 1.0
     */
    typical_p = 1.0;

    /**
     * If set to float strictly between 0 and 1, only tokens with a conditional probability greater than `epsilon_cutoff` will be sampled.
     * In the paper, suggested values range from 3e-4 to 9e-4, depending on the size of the model.
     * See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more details.
     * @type {number}
     * @default 0.0
     */
    epsilon_cutoff = 0.0;

    /**
     * Eta sampling is a hybrid of locally typical sampling and epsilon sampling.
     * If set to float strictly between 0 and 1, a token is only considered if it is greater than either `eta_cutoff` or `sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits)))`.
     * The latter term is intuitively the expected next token probability, scaled by `sqrt(eta_cutoff)`. In the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model.
     * See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more details.
     * @type {number}
     * @default 0.0
     */
    eta_cutoff = 0.0;

    /**
     * This value is subtracted from a beam's score if it generates a token same as any beam from other group at a particular time.
     * Note that `diversity_penalty` is only effective if `group beam search` is enabled.
     * @type {number}
     * @default 0.0
     */
    diversity_penalty = 0.0;

    /**
     * The parameter for repetition penalty. 1.0 means no penalty.
     * See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
     * @type {number}
     * @default 1.0
     */
    repetition_penalty = 1.0;

    /**
     * The parameter for encoder_repetition_penalty.
     * An exponential penalty on sequences that are not in the original input.
     * 1.0 means no penalty.
     * @type {number}
     * @default 1.0
     */
    encoder_repetition_penalty = 1.0;

    /**
     * Exponential penalty to the length that is used with beam-based generation.
     * It is applied as an exponent to the sequence length, which in turn is used to divide the score of the sequence.
     * Since the score is the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.
     * @type {number}
     * @default 1.0
     */
    length_penalty = 1.0;

    /**
     * If set to int > 0, all ngrams of that size can only occur once.
     * @type {number}
     * @default 0
     */
    no_repeat_ngram_size = 0;

    /**
     * List of token ids that are not allowed to be generated.
     * In order to get the token ids of the words that should not appear in the generated text, use
     * `tokenizer(bad_words, { add_prefix_space: true, add_special_tokens: false }).input_ids`.
     * @type {number[][]}
     * @default null
     */
    bad_words_ids = null;

    /**
     * List of token ids that must be generated.
     * If given a `number[][]`, this is treated as a simple list of words that must be included, the opposite to `bad_words_ids`.
     * If given `number[][][]`, this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one can allow different forms of each word.
     * @type {number[][]|number[][][]}
     * @default null
     */
    force_words_ids = null;

    /**
     * Whether to renormalize the logits after applying all the logits processors or warpers (including the custom ones).
     * It's highly recommended to set this flag to `true` as the search algorithms suppose the score logits are normalized but some logit processors or warpers break the normalization.
     * @type {boolean}
     * @default false
     */
    renormalize_logits = false;

    /**
     * Custom constraints that can be added to the generation to ensure that the output will contain the use of certain tokens as defined by `Constraint` objects, in the most sensible way possible.
     * @type {Object[]}
     * @default null
     */
    constraints = null;

    /**
     * The id of the token to force as the first generated token after the `decoder_start_token_id`.
     * Useful for multilingual models like mBART where the first generated token needs to be the target language token.
     * @type {number}
     * @default null
     */
    forced_bos_token_id = null;

    /**
     * The id of the token to force as the last generated token when `max_length` is reached.
     * Optionally, use a list to set multiple *end-of-sequence* tokens.
     * @type {number|number[]}
     * @default null
     */
    forced_eos_token_id = null;

    /**
     * Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to crash. Note that using `remove_invalid_values` can slow down generation.
     * @type {boolean}
     * @default false
     */
    remove_invalid_values = false;

    /**
     * This tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been generated.
     * The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty starts and `decay_factor` represents the factor of exponential decay.
     * @type {[number, number]}
     * @default null
     */
    exponential_decay_length_penalty = null;

    /**
     * A list of tokens that will be suppressed at generation.
     * The `SuppressTokens` logit processor will set their log probs to `-inf` so that they are not sampled.
     * @type {number[]}
     * @default null
     */
    suppress_tokens = null;

    /**
     * A list of tokens that will be suppressed at the beginning of the generation.
     * The `SuppressBeginTokens` logit processor will set their log probs to `-inf` so that they are not sampled.
     * @type {number[]}
     * @default null
     */
    begin_suppress_tokens = null;

    /**
     * A list of pairs of integers which indicates a mapping from generation indices to token indices that will be forced before sampling.
     * For example, `[[1, 123]]` means the second generated token will always be a token of index 123.
     * @type {[number, number][]}
     * @default null
     */
    forced_decoder_ids = null;

    /**
     * The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
     * Higher guidance scale encourages the model to generate samples that are more closely linked to the input
     * prompt, usually at the expense of poorer quality.
     * @type {number}
     * @default null
     */
    guidance_scale = null;

    // Parameters that define the output variables of `generate`
    /**
     * The number of independently computed returned sequences for each element in the batch.
     * @type {number}
     * @default 1
     */
    num_return_sequences = 1;

    /**
     * Whether or not to return the attentions tensors of all attention layers.
     * See `attentions` under returned tensors for more details.
     * @type {boolean}
     * @default false
     */
    output_attentions = false;

    /**
     * Whether or not to return the hidden states of all layers.
     * See `hidden_states` under returned tensors for more details.
     * @type {boolean}
     * @default false
     */
    output_hidden_states = false;

    /**
     * Whether or not to return the prediction scores.
     * See `scores` under returned tensors for more details.
     * @type {boolean}
     * @default false
     */
    output_scores = false;

    /**
     * Whether or not to return a `ModelOutput` instead of a plain tuple.
     * @type {boolean}
     * @default false
     */
    return_dict_in_generate = false;

    // Special tokens that can be used at generation time
    /**
     * The id of the *padding* token.
     * @type {number}
     * @default null
     */
    pad_token_id = null;

    /**
     * The id of the *beginning-of-sequence* token.
     * @type {number}
     * @default null
     */
    bos_token_id = null;

    /**
     * The id of the *end-of-sequence* token.
     * Optionally, use a list to set multiple *end-of-sequence* tokens.
     * @type {number|number[]}
     * @default null
     */
    eos_token_id = null;

    // Generation parameters exclusive to encoder-decoder models
    /**
     * If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the `decoder_input_ids`.
     * @type {number}
     * @default 0
     */
    encoder_no_repeat_ngram_size = 0;

    /**
     * If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
     * @type {number}
     * @default null
     */
    decoder_start_token_id = null;

    // Wild card
    /**
     * Additional generation kwargs will be forwarded to the `generate` function of the model.
     * Kwargs that are not present in `generate`'s signature will be used in the model forward pass.
     * @type {Object}
     * @default {}
     */
    generation_kwargs = {};

    /**
     * Create a generation config by copying the generation fields found in `config` over the
     * defaults declared above. Keys are selected with `pick` using this instance's own property
     * names, so keys of `config` that are not declared generation fields are not copied.
     *
     * @param {GenerationConfig|import('../configs.js').PretrainedConfig} [config] The source config.
     * When omitted (or `null`), the instance keeps all of its default values instead of throwing.
     */
    constructor(config) {
        // `config ?? {}` lets `new GenerationConfig()` produce a defaults-only config.
        Object.assign(this, (0,_utils_core_js__WEBPACK_IMPORTED_MODULE_0__.pick)(config ?? {}, Object.getOwnPropertyNames(this)));
    }
}
/***/ }),
/***/ "./src/generation/logits_process.js":
/*!******************************************!*\
!*** ./src/generation/logits_process.js ***!
\******************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ ClassifierFreeGuidanceLogitsProcessor: () => (/* binding */ ClassifierFreeGuidanceLogitsProcessor),
/* harmony export */ ForcedBOSTokenLogitsProcessor: () => (/* binding */ ForcedBOSTokenLogitsProcessor),
/* harmony export */ ForcedEOSTokenLogitsProcessor: () => (/* binding */ ForcedEOSTokenLogitsProcessor),
/* harmony export */ LogitsProcessor: () => (/* binding */ LogitsProcessor),
/* harmony export */ LogitsProcessorList: () => (/* binding */ LogitsProcessorList),
/* harmony export */ LogitsWarper: () => (/* binding */ LogitsWarper),
/* harmony export */ MinLengthLogitsProcessor: () => (/* binding */ MinLengthLogitsProcessor),
/* harmony export */ MinNewTokensLengthLogitsProcessor: () => (/* binding */ MinNewTokensLengthLogitsProcessor),
/* harmony export */ NoBadWordsLogitsProcessor: () => (/* binding */ NoBadWordsLogitsProcessor),
/* harmony export */ NoRepeatNGramLogitsProcessor: () => (/* binding */ NoRepeatNGramLogitsProcessor),
/* harmony export */ RepetitionPenaltyLogitsProcessor: () => (/* binding */ RepetitionPenaltyLogitsProcessor),
/* harmony export */ SuppressTokensAtBeginLogitsProcessor: () => (/* binding */ SuppressTokensAtBeginLogitsProcessor),
/* harmony export */ TemperatureLogitsWarper: () => (/* binding */ TemperatureLogitsWarper),
/* harmony export */ TopKLogitsWarper: () => (/* binding */ TopKLogitsWarper),
/* harmony export */ TopPLogitsWarper: () => (/* binding */ TopPLogitsWarper),
/* harmony export */ WhisperTimeStampLogitsProcessor: () => (/* binding */ WhisperTimeStampLogitsProcessor)
/* harmony export */ });
/* harmony import */ var _utils_generic_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ../utils/generic.js */ "./src/utils/generic.js");
/* harmony import */ var _utils_tensor_js__WEBPACK_IMPORTED_MODULE_1__ = __webpack_require__(/*! ../utils/tensor.js */ "./src/utils/tensor.js");
/* harmony import */ var _utils_maths_js__WEBPACK_IMPORTED_MODULE_2__ = __webpack_require__(/*! ../utils/maths.js */ "./src/utils/maths.js");
/**
* @module generation/logits_process
*/
/**
 * Abstract base for every logits processor used during generation.
 * Instances are callable (via `Callable`); subclasses supply the behavior by overriding `_call`.
 */
class LogitsProcessor extends _utils_generic_js__WEBPACK_IMPORTED_MODULE_0__.Callable {
    /**
     * Transform the logits for the current generation step.
     *
     * @abstract
     * @param {bigint[][]} input_ids The input ids generated so far, one array per batch row.
     * @param {Tensor} logits The logits to process.
     * @throws {Error} Always — the base class has no implementation.
     */
    _call(input_ids, logits) {
        throw new Error("`_call` should be implemented in a subclass");
    }
}
/**
 * Abstract base for logits warpers, i.e. processors applied when sampling with multinomial sampling.
 * Instances are callable (via `Callable`); subclasses supply the behavior by overriding `_call`.
 */
class LogitsWarper extends _utils_generic_js__WEBPACK_IMPORTED_MODULE_0__.Callable {
    /**
     * Warp the logits for the current generation step.
     *
     * @abstract
     * @param {bigint[][]} input_ids The input ids generated so far, one array per batch row.
     * @param {Tensor} logits The logits to process.
     * @throws {Error} Always — the base class has no implementation.
     */
    _call(input_ids, logits) {
        throw new Error("`_call` should be implemented in a subclass");
    }
}
/**
 * An ordered pipeline of logits processors. Calling the list runs every processor in insertion
 * order, feeding each one the output of the previous one.
 */
class LogitsProcessorList extends _utils_generic_js__WEBPACK_IMPORTED_MODULE_0__.Callable {
    /**
     * Create an empty processor list.
     */
    constructor() {
        super();
        this.processors = [];
    }

    /**
     * Append one processor to the end of the pipeline.
     *
     * @param {LogitsProcessor} item The logits processor function to add.
     */
    push(item) {
        this.processors.push(item);
    }

    /**
     * Append several processors to the end of the pipeline, preserving their order.
     *
     * @param {LogitsProcessor[]} items The logits processor functions to add.
     */
    extend(items) {
        this.processors.push(...items);
    }

    /**
     * Run the whole pipeline on a batch of logits.
     *
     * @param {bigint[][]} input_ids The input IDs for the language model.
     * @param {Tensor} logits The logits to process.
     * @returns {Tensor} The value returned by the last processor (`logits` itself when the list is empty).
     */
    _call(input_ids, logits) {
        // NOTE: Most processors modify logits inplace, but each result is still threaded through.
        return this.processors.reduce(
            (current, processor) => processor(input_ids, current),
            logits,
        );
    }

    /**
     * Iterate over the processors in pipeline order.
     */
    [Symbol.iterator]() {
        return this.processors[Symbol.iterator]();
    }
}
// DEPRECATED: https://github.com/huggingface/transformers/pull/29485
// /**
// * A logits processor that forces a specific token to be generated by the decoder.
// */
// export class ForceTokensLogitsProcessor extends LogitsProcessor {
// /**
// * Constructs a new instance of `ForceTokensLogitsProcessor`.
// *
// * @param {[number, number][]} forced_decoder_ids The ids of tokens that should be forced.
// */
// constructor(forced_decoder_ids) {
// super();
// // TODO: convert to `new Map(forced_decoder_ids)`
// this.force_token_map = Object.fromEntries(forced_decoder_ids ?? []);
// }
// /**
// * Apply the processor to the input logits.
// *
// * @param {bigint[][]} input_ids The input ids.
// * @param {Tensor} logits The logits to process.
// * @returns {Tensor} The processed logits.
// */
// _call(input_ids, logits) {
// console.log('this.force_token_map', this.force_token_map)
// console.log('call ForceTokensLogitsProcessor', input_ids, logits)
// console.log('input_ids.length', input_ids.length)
// let map = this.force_token_map[input_ids.length];
// if (map) { // There exists a mapping
// logits.data.fill(-Infinity)
// logits.data[map] = 0;
// }
// console.log('map', map)
// // throw Error("Not implemented")
// return logits;
// }
// }
/**
 * A LogitsProcessor that forces `bos_token_id` to be the token chosen for any batch row that
 * currently holds exactly one token.
 */
class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
    /**
     * Create a ForcedBOSTokenLogitsProcessor.
     * @param {number} bos_token_id The ID of the beginning-of-sequence token to be forced.
     */
    constructor(bos_token_id) {
        super();
        this.bos_token_id = bos_token_id;
    }

    /**
     * For each batch row of length 1, set every logit to `-Infinity` except `bos_token_id`, which is set to 0.
     * The logits are modified in place.
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits.
     * @returns {Tensor} The same `logits` object.
     */
    _call(input_ids, logits) {
        for (let row = 0; row < input_ids.length; ++row) {
            if (input_ids[row].length !== 1) {
                continue;
            }
            const scores = /** @type {Float32Array} */ (logits[row].data);
            scores.fill(-Infinity);
            scores[this.bos_token_id] = 0;
        }
        return logits;
    }
}
/**
 * A logits processor that forces an end-of-sequence token to be generated at the last step
 * allowed by `max_length`.
 */
class ForcedEOSTokenLogitsProcessor extends LogitsProcessor {
    /**
     * Create a ForcedEOSTokenLogitsProcessor.
     * @param {number} max_length The maximum length of the sequence to be generated.
     * @param {number|number[]} eos_token_id The id(s) of the *end-of-sequence* token.
     */
    constructor(max_length, eos_token_id) {
        super();
        this.max_length = max_length;
        // Always stored as an array; an array argument is kept as-is (not copied).
        if (Array.isArray(eos_token_id)) {
            this.eos_token_id = eos_token_id;
        } else {
            this.eos_token_id = [eos_token_id];
        }
    }

    /**
     * For each batch row whose length equals `max_length - 1`, set every logit to `-Infinity`
     * except the EOS id(s), which are set to 0. The logits are modified in place.
     *
     * @param {bigint[][]} input_ids The input ids.
     * @param {Tensor} logits The logits tensor.
     * @returns {Tensor} The same `logits` object.
     */
    _call(input_ids, logits) {
        const forced_length = this.max_length - 1;
        for (let row = 0; row < input_ids.length; ++row) {
            if (input_ids[row].length !== forced_length) {
                continue;
            }
            const scores = /** @type {Float32Array} */ (logits[row].data);
            scores.fill(-Infinity);
            for (const eos_id of this.eos_token_id) {
                scores[eos_id] = 0;
            }
        }
        return logits;
    }
}
/**
 * A LogitsProcessor that suppresses a list of tokens at the generation step where the sequence
 * length equals `begin_index`. This ensures that the tokens in `begin_suppress_tokens` are not
 * sampled at the beginning of the generation.
 */
class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor {
    /**
     * Create a SuppressTokensAtBeginLogitsProcessor.
     * @param {number[]} begin_suppress_tokens The IDs of the tokens to suppress.
     * @param {number} begin_index The input sequence length at which the suppression is applied.
     */
    constructor(begin_suppress_tokens, begin_index) {
        super();
        this.begin_suppress_tokens = begin_suppress_tokens;
        this.begin_index = begin_index;
    }

    /**
     * Set the logits of the suppressed tokens to `-Infinity` for every batch row whose length
     * equals `begin_index`. The logits are modified in place.
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits.
     * @returns {Tensor} The same `logits` object.
     */
    _call(input_ids, logits) {
        for (let row = 0; row < input_ids.length; ++row) {
            if (input_ids[row].length !== this.begin_index) {
                continue;
            }
            const scores = /** @type {Float32Array} */ (logits[row].data);
            for (const token_id of this.begin_suppress_tokens) {
                scores[token_id] = -Infinity;
            }
        }
        return logits;
    }
}
/**
 * A LogitsProcessor that handles adding timestamps to generated text.
 * Timestamp tokens are the ids from `no_timestamps_token_id + 1` upward.
 */
class WhisperTimeStampLogitsProcessor extends LogitsProcessor {
    /**
     * Constructs a new WhisperTimeStampLogitsProcessor.
     * @param {import('../models/whisper/generation_whisper.js').WhisperGenerationConfig} generate_config The config object passed to the `generate()` method of a transformer model.
     * @param {number[]} init_tokens The initial tokens of the input sequence.
     */
    constructor(generate_config, init_tokens) {
        super();
        // When several EOS ids are configured, only the first one is used.
        this.eos_token_id =
            Array.isArray(generate_config.eos_token_id)
                ? generate_config.eos_token_id[0]
                : generate_config.eos_token_id;

        this.no_timestamps_token_id = generate_config.no_timestamps_token_id;
        // Timestamp tokens start immediately after <|notimestamps|>.
        this.timestamp_begin = this.no_timestamps_token_id + 1;

        this.begin_index = init_tokens.length;
        if (init_tokens.at(-1) === this.no_timestamps_token_id) {
            this.begin_index -= 1;
        }
        this.max_initial_timestamp_index = generate_config.max_initial_timestamp_index;
    }

    /**
     * Modify the logits (in place) to handle timestamp tokens.
     * @param {bigint[][]} input_ids The input sequence of tokens.
     * @param {Tensor} logits The logits output by the model.
     * @returns {Tensor} The modified logits.
     */
    _call(input_ids, logits) {
        for (let i = 0; i < input_ids.length; ++i) {
            const batch_logits_data = /** @type {Float32Array} */(logits[i].data);

            // suppress <|notimestamps|> which is handled by without_timestamps
            batch_logits_data[this.no_timestamps_token_id] = -Infinity;

            // One step before `begin_index`, force the first timestamp token.
            if (input_ids[i].length === this.begin_index - 1) {
                batch_logits_data.fill(-Infinity);
                batch_logits_data[this.timestamp_begin] = 0;
                continue;
            }

            // timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
            const seq = input_ids[i].slice(this.begin_index);
            const last_was_timestamp = seq.length >= 1 && seq[seq.length - 1] >= this.timestamp_begin;
            const penultimate_was_timestamp = seq.length < 2 || seq[seq.length - 2] >= this.timestamp_begin;

            if (last_was_timestamp) {
                if (penultimate_was_timestamp) { // has to be non-timestamp
                    batch_logits_data.subarray(this.timestamp_begin).fill(-Infinity);
                } else { // cannot be normal text tokens
                    batch_logits_data.subarray(0, this.eos_token_id).fill(-Infinity);
                }
            }

            // apply the `max_initial_timestamp` option.
            // NOTE: `!= null` also skips an absent (`undefined`) option; with `!== null`, `last_allowed`
            // would be NaN and `subarray(NaN)` would mask the entire vocabulary.
            if (input_ids[i].length === this.begin_index && this.max_initial_timestamp_index != null) {
                const last_allowed = this.timestamp_begin + this.max_initial_timestamp_index;
                batch_logits_data.subarray(last_allowed + 1).fill(-Infinity);
            }

            // if sum of probability over timestamps is above any other token, sample timestamp
            const logprobs = (0,_utils_maths_js__WEBPACK_IMPORTED_MODULE_2__.log_softmax)(batch_logits_data);

            // log(sum(exp(logprob))) over the timestamp tokens, accumulated in float64 without an
            // intermediate array; an empty range yields log(0) = -Infinity instead of throwing.
            let timestamp_prob_sum = 0;
            for (let j = this.timestamp_begin; j < logprobs.length; ++j) {
                timestamp_prob_sum += Math.exp(logprobs[j]);
            }
            const timestamp_logprob = Math.log(timestamp_prob_sum);
            const max_text_token_logprob = (0,_utils_maths_js__WEBPACK_IMPORTED_MODULE_2__.max)(logprobs.subarray(0, this.timestamp_begin))[0];

            if (timestamp_logprob > max_text_token_logprob) {
                batch_logits_data.subarray(0, this.timestamp_begin).fill(-Infinity);
            }
        }

        return logits;
    }
}
/**
 * A logits processor that disallows ngrams of a certain size to be repeated.
 *
 * For each batch item, any token that would complete an n-gram (n = `no_repeat_ngram_size`)
 * already present in that item's sequence has its logit set to `-Infinity`.
 */
class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
    /**
     * Create a NoRepeatNGramLogitsProcessor.
     * @param {number} no_repeat_ngram_size The no-repeat-ngram size. All ngrams of this size can only occur once.
     */
    constructor(no_repeat_ngram_size) {
        super();
        this.no_repeat_ngram_size = no_repeat_ngram_size;
    }

    /**
     * Index every n-gram of a sequence by its first n-1 tokens.
     * @param {bigint[]} prevInputIds List of previous input ids
     * @returns {Map<string, number[]>} Map from the JSON-encoded (n-1)-token prefix (e.g. `"[3,7]"`)
     * to the tokens that followed that prefix in `prevInputIds`, in order and with duplicates kept.
     */
    getNgrams(prevInputIds) {
        const n = this.no_repeat_ngram_size;

        /** @type {Map<string, number[]>} */
        const generatedNgram = new Map();
        for (let j = 0; j < prevInputIds.length + 1 - n; ++j) {
            const ngram = prevInputIds.slice(j, j + n).map(Number);
            const prevNgramKey = JSON.stringify(ngram.slice(0, -1));
            const prevNgramValue = generatedNgram.get(prevNgramKey) ?? [];
            prevNgramValue.push(ngram.at(-1));
            generatedNgram.set(prevNgramKey, prevNgramValue);
        }
        return generatedNgram;
    }

    /**
     * Look up the tokens that must not follow the sequence's current (n-1)-token suffix.
     * @param {Map<string, number[]>} bannedNgrams Prefix map as produced by `getNgrams`.
     * @param {bigint[]} prevInputIds List of previous input ids
     * @returns {number[]} Token ids that would repeat an existing n-gram (may contain duplicates).
     */
    getGeneratedNgrams(bannedNgrams, prevInputIds) {
        const ngramIdx = prevInputIds.slice(prevInputIds.length + 1 - this.no_repeat_ngram_size, prevInputIds.length);
        return bannedNgrams.get(JSON.stringify(ngramIdx.map(Number))) ?? [];
    }

    /**
     * Calculate banned n-gram tokens
     * @param {bigint[]} prevInputIds List of previous input ids
     * @returns {number[]} Token ids that must not be generated next (empty if none).
     */
    calcBannedNgramTokens(prevInputIds) {
        if (prevInputIds.length + 1 < this.no_repeat_ngram_size) {
            // Even with one more token the sequence cannot contain a full n-gram yet.
            return [];
        }
        const generatedNgrams = this.getNgrams(prevInputIds);
        return this.getGeneratedNgrams(generatedNgrams, prevInputIds);
    }

    /**
     * Apply the no-repeat-ngram processor to the logits (in place).
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits; `logits[i]` is the row for batch item `i`.
     * @returns {Object} The logits with no-repeat-ngram processing.
     */
    _call(input_ids, logits) {
        for (let i = 0; i < input_ids.length; ++i) {
            const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
            const bannedTokens = this.calcBannedNgramTokens(input_ids[i]);
            for (const token of bannedTokens) {
                batch_logits_data[token] = -Infinity;
            }
        }
        return logits;
    }
}
/**
 * A logits processor that penalises repeated output tokens.
 *
 * Every occurrence of a token id in a sequence rescales that token's logit once: negative logits
 * are multiplied by `penalty`, the rest are divided by it. A token seen `k` times is therefore
 * rescaled `k` times.
 */
class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
    /**
     * Create a RepetitionPenaltyLogitsProcessor.
     * @param {number} penalty The penalty to apply for repeated tokens.
     */
    constructor(penalty) {
        super();
        this.penalty = penalty;
    }

    /**
     * Apply the repetition penalty to the logits (in place).
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits; `logits[b]` is the row for batch item `b`.
     * @returns {Object} The logits with repetition penalty processing.
     */
    _call(input_ids, logits) {
        const { penalty } = this;
        for (const [batch_idx, sequence] of input_ids.entries()) {
            const row = /** @type {Float32Array} */(logits[batch_idx].data);
            for (const id of sequence) {
                const t = Number(id);
                row[t] = row[t] < 0 ? row[t] * penalty : row[t] / penalty;
            }
        }
        return logits;
    }
}
/**
 * A logits processor that enforces a minimum number of tokens.
 * While a sequence is shorter than `min_length`, every end-of-sequence id is masked with `-Infinity`.
 */
class MinLengthLogitsProcessor extends LogitsProcessor {
    /**
     * Create a MinLengthLogitsProcessor.
     * @param {number} min_length The minimum length below which the score of `eos_token_id` is set to negative infinity.
     * @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token.
     */
    constructor(min_length, eos_token_id) {
        super();
        this.min_length = min_length;
        // Always keep a list so `_call` can iterate uniformly.
        if (Array.isArray(eos_token_id)) {
            this.eos_token_id = eos_token_id;
        } else {
            this.eos_token_id = [eos_token_id];
        }
    }

    /**
     * Mask the EOS logits of every sequence that is still too short (in place).
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits; `logits[b]` is the row for batch item `b`.
     * @returns {Object} The processed logits.
     */
    _call(input_ids, logits) {
        input_ids.forEach((ids, b) => {
            if (!(ids.length < this.min_length)) {
                return;
            }
            const row = /** @type {Float32Array} */(logits[b].data);
            this.eos_token_id.forEach((eos) => {
                row[eos] = -Infinity;
            });
        });
        return logits;
    }
}
/**
 * A logits processor that enforces a minimum number of new tokens.
 * New tokens are counted as `input_ids[i].length - prompt_length_to_skip`; while that count is
 * below `min_new_tokens`, every end-of-sequence id is masked with `-Infinity`.
 */
class MinNewTokensLengthLogitsProcessor extends LogitsProcessor {
    /**
     * Create a MinNewTokensLengthLogitsProcessor.
     * @param {number} prompt_length_to_skip The input tokens length.
     * @param {number} min_new_tokens The minimum *new* tokens length below which the score of `eos_token_id` is set to negative infinity.
     * @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token.
     */
    constructor(prompt_length_to_skip, min_new_tokens, eos_token_id) {
        super();
        this.prompt_length_to_skip = prompt_length_to_skip;
        this.min_new_tokens = min_new_tokens;
        if (Array.isArray(eos_token_id)) {
            this.eos_token_id = eos_token_id;
        } else {
            this.eos_token_id = [eos_token_id];
        }
    }

    /**
     * Mask the EOS logits of every sequence that has not produced enough new tokens (in place).
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits; `logits[b]` is the row for batch item `b`.
     * @returns {Object} The processed logits.
     */
    _call(input_ids, logits) {
        input_ids.forEach((ids, b) => {
            const generated = ids.length - this.prompt_length_to_skip;
            if (!(generated < this.min_new_tokens)) {
                return;
            }
            const row = /** @type {Float32Array} */(logits[b].data);
            for (const eos of this.eos_token_id) {
                row[eos] = -Infinity;
            }
        });
        return logits;
    }
}
/**
 * A logits processor that prevents banned token sequences ("bad words") from being generated.
 *
 * For each bad-word sequence, if the current input ends with all of its tokens except the last,
 * the logit of that last token is set to `-Infinity`. Single-token bad words are therefore
 * always banned.
 */
class NoBadWordsLogitsProcessor extends LogitsProcessor {
    /**
     * Create a `NoBadWordsLogitsProcessor`.
     * @param {number[][]} bad_words_ids List of list of token ids that are not allowed to be generated.
     * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
     */
    constructor(bad_words_ids, eos_token_id) {
        super();
        this.bad_words_ids = bad_words_ids;
        this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
    }

    /**
     * Apply logit processor (in place).
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits; `logits[i]` is the row for batch item `i`.
     * @returns {Object} The processed logits.
     */
    _call(input_ids, logits) {
        for (let i = 0; i < input_ids.length; ++i) {
            const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
            const ids = input_ids[i];
            for (const bad_word_ids of this.bad_words_ids) {
                // An empty entry names no token to ban.
                if (bad_word_ids.length === 0) continue;

                const prefix_length = bad_word_ids.length - 1;
                // There aren't enough tokens to match the banned prefix.
                if (ids.length < prefix_length) continue;

                // Whether the sequence ends with the bad word's prefix (all ids except the last).
                let mark = true;
                for (let j = 1; j <= prefix_length; ++j) {
                    // NOTE: We use != instead of !== to compare bigint and number
                    // @ts-ignore
                    if (bad_word_ids.at(-j - 1) != ids.at(-j)) {
                        // We have found a mismatch
                        mark = false;
                        break;
                    }
                }
                if (mark) {
                    batch_logits_data[bad_word_ids.at(-1)] = -Infinity;
                }
            }
        }
        return logits;
    }
}
/**
 * [`LogitsProcessor`] for classifier free guidance (CFG).
 *
 * The logits' batch dimension must be twice the number of input sequences. The first half holds
 * the conditional logits (from the input prompt) and the second half the unconditional logits
 * (from an empty or 'null' prompt). The result is
 * `uncond + (cond - uncond) * guidance_scale`.
 *
 * See [the paper](https://arxiv.org/abs/2306.05284) for more information.
 */
class ClassifierFreeGuidanceLogitsProcessor extends LogitsProcessor {
    /**
     * Create a `ClassifierFreeGuidanceLogitsProcessor`.
     * @param {number} guidance_scale The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
     * Higher guidance scale encourages the model to generate samples that are more closely linked to the input
     * prompt, usually at the expense of poorer quality.
     * @throws {Error} If `guidance_scale <= 1`.
     */
    constructor(guidance_scale) {
        super();
        if (guidance_scale <= 1) {
            throw new Error(`Require guidance scale >1 to use the classifier free guidance processor, got guidance scale ${guidance_scale}.`);
        }
        this.guidance_scale = guidance_scale;
    }

    /**
     * Combine the conditional and unconditional halves of the logits.
     * @param {bigint[][]} input_ids The input IDs.
     * @param {Tensor} logits The logits, with batch size `2 * input_ids.length`.
     * @returns {Object} The unconditional half, overwritten with the guided logits.
     * @throws {Error} If the logits' batch size is not twice the input batch size.
     */
    _call(input_ids, logits) {
        const unguided_bsz = input_ids.length;
        const total_bsz = logits.dims[0];
        if (total_bsz !== 2 * unguided_bsz) {
            throw new Error(
                `Logits should have twice the batch size of the input ids, the first half of batches corresponding to ` +
                `the conditional inputs, and the second half of batches corresponding to the unconditional inputs. Got ` +
                `batch size ${total_bsz} for the logits and ${unguided_bsz} for the input ids.`
            );
        }

        const cond = logits.slice([0, unguided_bsz], null).data;
        const guided = logits.slice([unguided_bsz, total_bsz], null);
        // Written into the unconditional half to avoid allocating another tensor.
        const out = guided.data;
        const scale = this.guidance_scale;
        out.forEach((u, idx) => {
            out[idx] = u + (cond[idx] - u) * scale;
        });
        return guided;
    }
}
/**
 * [`LogitsWarper`] for temperature (exponential scaling output probability distribution), which effectively means
 * that it can control the randomness of the predicted tokens. Often used together with [`TopPLogitsWarper`] and [`TopKLogitsWarper`].
 */
class TemperatureLogitsWarper extends LogitsWarper {
    /**
     * Create a `TemperatureLogitsWarper`.
     * @param {number} temperature Strictly positive float value used to modulate the logits distribution.
     * A value smaller than `1` decreases randomness (and vice versa), with `0` being equivalent to shifting
     * all probability mass to the most likely token.
     * @throws {Error} If `temperature` is not a strictly positive number (this includes `0` and `NaN`).
     */
    constructor(temperature) {
        super();
        // `!(temperature > 0)` also rejects NaN, which `temperature <= 0` would let through.
        if (typeof temperature !== 'number' || !(temperature > 0)) {
            let errorMessage =
                `\`temperature\` (=${temperature}) must be a strictly positive float, otherwise your next token scores will be invalid.`;
            if (temperature === 0) {
                errorMessage += " If you're looking for greedy decoding strategies, set `do_sample=false`.";
            }
            // Previously the message was built but never thrown, so invalid temperatures
            // silently produced Infinity/NaN logits in `_call`.
            throw new Error(errorMessage);
        }
        this.temperature = temperature;
    }

    /**
     * Apply logit warper: divide every logit by the temperature (in place).
     * @param {bigint[][]} input_ids The input IDs (unused).
     * @param {Tensor} logits The logits.
     * @returns {Object} The processed logits.
     */
    _call(input_ids, logits) {
        const batch_logits_data = /** @type {Float32Array} */(logits.data);
        for (let i = 0; i < batch_logits_data.length; ++i) {
            batch_logits_data[i] /= this.temperature;
        }
        return logits;
    }
}
/**
 * [`LogitsWarper`] that performs top-p (nucleus) filtering: only the smallest set of most probable
 * tokens whose probabilities sum to at least `top_p` is kept.
 * Often used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`].
 */
class TopPLogitsWarper extends LogitsWarper {
    /**
     * Create a `TopPLogitsWarper`.
     * @param {number} top_p If set to < 1, only the smallest set of most probable tokens with
     * probabilities that add up to `top_p` or higher are kept for generation.
     * @param {Object} options Additional options for the top-p sampling.
     * @param {number} [options.filter_value=-Infinity] All filtered values will be set to this float value.
     * @param {number} [options.min_tokens_to_keep=1] Minimum number of tokens that cannot be filtered.
     */
    constructor(top_p, {
        filter_value = -Infinity,
        min_tokens_to_keep = 1,
    } = {}) {
        super();
        if (top_p < 0 || top_p > 1.0) {
            throw new Error(`\`top_p\` must be a float > 0 and < 1, but is ${top_p}`);
        }
        if (!Number.isInteger(min_tokens_to_keep) || min_tokens_to_keep < 1) {
            throw new Error(`\`min_tokens_to_keep\` must be a positive integer, but is ${min_tokens_to_keep}`);
        }
        this.top_p = top_p;
        this.filter_value = filter_value;
        this.min_tokens_to_keep = min_tokens_to_keep;
    }

    /**
     * Apply top-p filtering to the logits (in place).
     *
     * Each row along the last dimension is processed independently. Tokens are ordered from least
     * to most likely, and those whose cumulative softmax probability (counted from the least likely
     * token) is `<= 1 - top_p` are set to `filter_value`. The `min_tokens_to_keep` most likely
     * tokens are never filtered.
     * @param {bigint[][]} input_ids The input IDs (unused).
     * @param {Tensor} logits The logits; the last dimension is the vocabulary.
     * @returns {Object} The same `logits` object, filtered in place.
     */
    _call(input_ids, logits) {
        const data = /** @type {Float32Array} */(logits.data);
        const vocab_size = logits.dims.at(-1);
        if (!(vocab_size > 0)) {
            return logits;
        }
        const threshold = 1 - this.top_p;
        for (let start = 0; start < data.length; start += vocab_size) {
            const row = data.subarray(start, start + vocab_size);

            // Numerically stable softmax denominator.
            let max_logit = -Infinity;
            for (const x of row) {
                if (x > max_logit) max_logit = x;
            }
            if (max_logit === -Infinity) {
                continue; // every token is already filtered out
            }
            let total = 0;
            for (const x of row) {
                total += Math.exp(x - max_logit);
            }

            // Token positions from least to most likely.
            const order = Array.from(row.keys()).sort((a, b) => row[a] - row[b]);
            const removable = order.length - this.min_tokens_to_keep;
            let cumulative = 0;
            for (let j = 0; j < removable; ++j) {
                cumulative += Math.exp(row[order[j]] - max_logit) / total;
                // The cumulative sum only grows, so the first token past the threshold ends the scan.
                if (cumulative > threshold) break;
                row[order[j]] = this.filter_value;
            }
        }
        return logits;
    }
}
/**
 * [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
 * Often used together with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`].
 */
class TopKLogitsWarper extends LogitsWarper {
    /**
     * Create a `TopKLogitsWarper`.
     * @param {number} top_k If set to > 0, only the top `top_k` tokens are kept for generation.
     * @param {Object} options Additional options for the top-k sampling.
     * @param {number} [options.filter_value=-Infinity] All filtered values will be set to this float value.
     * @param {number} [options.min_tokens_to_keep=1] Minimum number of tokens that cannot be filtered.
     */
    constructor(top_k, {
        filter_value = -Infinity,
        min_tokens_to_keep = 1,
    } = {}) {
        super();
        if (!Number.isInteger(top_k) || top_k < 0) {
            throw new Error(`\`top_k\` must be a positive integer, but is ${top_k}`);
        }
        // `min_tokens_to_keep` acts as a lower bound on k.
        this.top_k = Math.max(top_k, min_tokens_to_keep);
        this.filter_value = filter_value;
    }

    /**
     * Apply top-k filtering to the logits (in place).
     *
     * Each row along the last dimension is processed independently. Every logit strictly smaller
     * than the row's k-th largest value is set to `filter_value`, so ties at the cut-off are kept.
     * @param {bigint[][]} input_ids The input IDs (unused).
     * @param {Tensor} logits The logits; the last dimension is the vocabulary.
     * @returns {Object} The same `logits` object, filtered in place.
     */
    _call(input_ids, logits) {
        const data = /** @type {Float32Array} */(logits.data);
        const vocab_size = logits.dims.at(-1);
        const k = Math.min(this.top_k, vocab_size);
        // Nothing to filter when every token is kept (also guards an empty last dimension).
        if (!(k > 0) || k >= vocab_size) {
            return logits;
        }
        for (let start = 0; start < data.length; start += vocab_size) {
            const row = data.subarray(start, start + vocab_size);
            const kth_largest = Array.from(row).sort((a, b) => b - a)[k - 1];
            for (let j = 0; j < row.length; ++j) {
                if (row[j] < kth_largest) {
                    row[j] = this.filter_value;
                }
            }
        }
        return logits;
    }
}
/***/ }),
/***/ "./src/generation/logits_sampler.js":
/*!******************************************!*\
!*** ./src/generation/logits_sampler.js ***!
\******************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ LogitsSampler: () => (/* binding */ LogitsSampler)
/* harmony export */ });
/* harmony import */ var _utils_generic_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ../utils/generic.js */ "./src/utils/generic.js");
/* harmony import */ var _utils_tensor_js__WEBPACK_IMPORTED_MODULE_1__ = __webpack_require__(/*! ../utils/tensor.js */ "./src/utils/tensor.js");
/* harmony import */ var _utils_maths_js__WEBPACK_IMPORTED_MODULE_2__ = __webpack_require__(/*! ../utils/maths.js */ "./src/utils/maths.js");
/* harmony import */ var _generation_configuration_utils_js__WEBPACK_IMPORTED_MODULE_3__ = __webpack_require__(/*! ../generation/configuration_utils.js */ "./src/generation/configuration_utils.js");
/**
* @module generation/logits_sampler
*/
/**
 * Sampler is a base class for all sampling methods used for text generation.
 * Subclasses implement `sample`. `LogitsSampler.getSampler` picks a subclass from a generation config.
 */
class LogitsSampler extends _utils_generic_js__WEBPACK_IMPORTED_MODULE_0__.Callable {
    /**
     * Creates a new Sampler object with the specified generation config.
     * @param {GenerationConfig} generation_config The generation config.
     */
    constructor(generation_config) {
        super();
        this.generation_config = generation_config;
    }

    /**
     * Executes the sampler on the given logits by delegating to `sample`.
     * @param {Tensor} logits The logits to sample from.
     * @returns {Promise<[bigint, number][]>} The `[token_id, score]` pairs produced by `sample`.
     */
    async _call(logits) {
        return this.sample(logits);
    }

    /**
     * Abstract method for sampling the logits.
     * @param {Tensor} logits
     * @throws {Error} If not implemented in subclass.
     * @returns {Promise<[bigint, number][]>}
     */
    async sample(logits) {
        throw new Error("sample should be implemented in subclasses.");
    }

    /**
     * Returns a copy of one row of `logits`, where a row is the last dimension (`logits.dims.at(-1)` values).
     * The values are returned as-is: no temperature or other scaling is applied here.
     * @param {Tensor} logits
     * @param {number} index Row to extract; `-1` selects the last row.
     * @returns {Float32Array}
     */
    getLogits(logits, index) {
        const vocabSize = logits.dims.at(-1);
        const logs = /** @type {Float32Array} */(logits.data);
        if (index === -1) {
            // The generic formula would give slice(-vocabSize, 0), which is empty.
            return logs.slice(-vocabSize);
        }
        const startIndex = index * vocabSize;
        return logs.slice(startIndex, startIndex + vocabSize);
    }

    /**
     * Selects an index at random, with probability proportional to each entry.
     * The entries do not need to sum to 1.
     * @param {import("../transformers.js").DataArray} probabilities Non-negative weights to sample from.
     * @returns {number} The index of the selected item. Falls back to 0 if no index is selected
     * (e.g. for an empty input); 0 is the most probable entry when the weights are sorted in
     * descending order.
     */
    randomSelect(probabilities) {
        let sumProbabilities = 0;
        for (let i = 0; i < probabilities.length; ++i) {
            sumProbabilities += probabilities[i];
        }

        let r = Math.random() * sumProbabilities;
        for (let i = 0; i < probabilities.length; ++i) {
            r -= probabilities[i];
            if (r <= 0) {
                return i;
            }
        }
        return 0; // return first (most probable) as a fallback
    }

    /**
     * Returns a Sampler object based on the specified options.
     * `do_sample` selects `MultinomialSampler`. Otherwise `num_beams > 1` selects `BeamSearchSampler`,
     * and anything else selects `GreedySampler`.
     * @param {GenerationConfig} generation_config An object containing options for the sampler.
     * @returns {LogitsSampler} A Sampler object.
     * @throws {Error} For greedy decoding with `num_return_sequences > 1`.
     */
    static getSampler(generation_config) {
        // - *greedy decoding*: `num_beams=1` and `do_sample=False`
        // - *contrastive search*: `penalty_alpha>0` and `top_k>1`
        // - *multinomial sampling*: `num_beams=1` and `do_sample=True`
        // - *beam-search decoding*: `num_beams>1` and `do_sample=False`
        // - *beam-search multinomial sampling*: `num_beams>1` and `do_sample=True`
        // - *diverse beam-search decoding*: `num_beams>1` and `num_beam_groups>1`
        // - *constrained beam-search decoding*: `constraints!=None` or `force_words_ids!=None`

        // NOTE: beam search is implemented directly into the generation function
        if (generation_config.do_sample) {
            return new MultinomialSampler(generation_config);
        } else if (generation_config.num_beams > 1) {
            return new BeamSearchSampler(generation_config);
        } else {
            if (generation_config.num_return_sequences > 1) {
                throw new Error(`num_return_sequences has to be 1 when doing greedy search, but is ${generation_config.num_return_sequences}.`);
            }
            return new GreedySampler(generation_config);
        }
    }
}
/**
 * Sampler that always picks the highest-scoring token (greedy decoding).
 */
class GreedySampler extends LogitsSampler {
    /**
     * Pick the index of the largest logit.
     * @param {Tensor} logits
     * @returns {Promise<[bigint, number][]>} A single `[token_id, 0]` pair. The score is a placeholder:
     * greedy search gives its choice probability 1, and log(1) = 0.
     */
    async sample(logits) {
        // Softmax preserves ordering, so the raw logits' argmax is also the most probable token.
        const [, best_index] = (0,_utils_maths_js__WEBPACK_IMPORTED_MODULE_2__.max)(logits.data);
        const choice = [BigInt(best_index), 0];
        return [choice];
    }
}
/**
 * Sampler that draws tokens at random according to their softmax probabilities.
 * Draws can be restricted to the `top_k` highest-scoring candidates.
 */
class MultinomialSampler extends LogitsSampler {
    /**
     * Draw `generation_config.num_beams` tokens (with replacement) from the logits.
     * @param {Tensor} logits
     * @returns {Promise<[bigint, number][]>} One `[token_id, log_probability]` pair per draw. The
     * log-probability is relative to the candidate set (top-k when `top_k > 0`).
     */
    async sample(logits) {
        const vocab_size = logits.dims.at(-1);
        const { top_k } = this.generation_config;
        const k = top_k > 0 ? Math.min(top_k, vocab_size) : vocab_size;

        const [top_values, top_indices] = await (0,_utils_tensor_js__WEBPACK_IMPORTED_MODULE_1__.topk)(logits, k);
        const probabilities = (0,_utils_maths_js__WEBPACK_IMPORTED_MODULE_2__.softmax)(/** @type {Float32Array} */(top_values.data));

        const draw = () => {
            const choice = this.randomSelect(probabilities);
            return [top_indices.data[choice], Math.log(probabilities[choice])];
        };
        return Array.from({ length: this.generation_config.num_beams }, draw);
    }
}
/**
 * Sampler that returns the `num_beams` highest-scoring candidate tokens for beam search.
 */
class BeamSearchSampler extends LogitsSampler {
    /**
     * Return the best candidates from the logits.
     * @param {Tensor} logits
     * @returns {Promise<[bigint, number][]>} Up to `num_beams` `[token_id, log_probability]` pairs in
     * descending score order. The log-probabilities are normalised over the `k` requested candidates.
     */
    async sample(logits) {
        const vocab_size = logits.dims.at(-1);
        const { top_k, num_beams } = this.generation_config;

        let k = vocab_size; // defaults to vocab size
        if (top_k > 0) {
            // Request at least `num_beams` candidates, still capped at the vocabulary size.
            // With fewer, the beams past the k-th would be `[undefined, NaN]` pairs.
            k = Math.min(Math.max(top_k, num_beams || 0), vocab_size);
        }

        // Get top k tokens
        const [v, i] = await (0,_utils_tensor_js__WEBPACK_IMPORTED_MODULE_1__.topk)(logits, k);

        // Compute softmax over logits
        const probabilities = (0,_utils_maths_js__WEBPACK_IMPORTED_MODULE_2__.softmax)(/** @type {Float32Array} */(v.data));

        // Never index past the k candidates returned by `topk` (e.g. when vocab_size < num_beams).
        return Array.from({ length: Math.min(num_beams, k) }, (_, x) => [
            i.data[x], // token id
            Math.log(probabilities[x]), // score
        ]);
    }
}
/***/ }),
/***/ "./src/generation/stopping_criteria.js":
/*!*********************************************!*\
!*** ./src/generation/stopping_criteria.js ***!
\*********************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ EosTokenCriteria: () => (/* binding */ EosTokenCriteria),
/* harmony export */ InterruptableStoppingCriteria: () => (/* binding */ InterruptableStoppingCriteria),
/* harmony export */ MaxLengthCriteria: () => (/* binding */ MaxLengthCriteria),
/* harmony export */ StoppingCriteria: () => (/* binding */ StoppingCriteria),
/* harmony export */ StoppingCriteriaList: () => (/* binding */ StoppingCriteriaList)
/* harmony export */ });
/* harmony import */ var _utils_generic_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ../utils/generic.js */ "./src/utils/generic.js");
/**
* @module generation/stopping_criteria
*/
// NOTE:
// Stopping Criteria returns a list of `batch_size` booleans, indicating whether each sequence in the batch should be stopped.
/**
 * Abstract base class for all stopping criteria that can be applied during generation.
 * Subclasses override `_call` and return one boolean per batch item.
 */
class StoppingCriteria extends _utils_generic_js__WEBPACK_IMPORTED_MODULE_0__.Callable {
    /**
     * Decide, for each sequence in the batch, whether generation should stop.
     * @param {number[][]} input_ids Token ids of shape `(batch_size, sequence_length)`.
     * @param {number[][]} scores Prediction scores of shape `(batch_size, config.vocab_size)`,
     * either before or after SoftMax.
     * @returns {boolean[]} One flag per sequence; `true` means that sequence should stop.
     * @throws {Error} Always; the base class must be subclassed.
     */
    _call(input_ids, scores) {
        const message = "StoppingCriteria needs to be subclassed";
        throw new Error(message);
    }
}
/**
 * An ordered, iterable collection of stopping criteria that is itself callable.
 * A sequence counts as done as soon as any contained criterion reports it done.
 */
class StoppingCriteriaList extends _utils_generic_js__WEBPACK_IMPORTED_MODULE_0__.Callable {
    /**
     * Constructs a new, empty instance of `StoppingCriteriaList`.
     */
    constructor() {
        super();
        this.criteria = [];
    }

    /**
     * Adds a new stopping criterion to the list.
     * @param {StoppingCriteria} item The stopping criterion to add.
     */
    push(item) {
        this.criteria.push(item);
    }

    /**
     * Adds multiple stopping criteria to the list.
     * @param {StoppingCriteria|StoppingCriteriaList|StoppingCriteria[]} items The stopping criteria to add.
     */
    extend(items) {
        let to_add;
        if (items instanceof StoppingCriteriaList) {
            to_add = items.criteria;
        } else if (items instanceof StoppingCriteria) {
            to_add = [items];
        } else {
            to_add = items;
        }
        this.criteria.push(...to_add);
    }

    /**
     * OR together the verdicts of every criterion.
     * @param {number[][]} input_ids
     * @param {number[][]} scores
     * @returns {boolean[]} One flag per sequence.
     */
    _call(input_ids, scores) {
        const is_done = new Array(input_ids.length).fill(false);
        for (const criterion of this.criteria) {
            const verdict = criterion(input_ids, scores);
            is_done.forEach((done, idx) => {
                is_done[idx] = done || verdict[idx];
            });
        }
        return is_done;
    }

    [Symbol.iterator]() {
        return this.criteria.values();
    }
}
/**
 * Stops generation once a sequence's total token count reaches `max_length`.
 * For decoder-only models this count includes the initial prompt tokens.
 */
class MaxLengthCriteria extends StoppingCriteria {
    /**
     * @param {number} max_length The maximum length that the output sequence can have in number of tokens.
     * @param {number} [max_position_embeddings=null] The maximum model length, as defined by the model's `config.max_position_embeddings` attribute.
     * It is stored but not consulted by `_call`.
     */
    constructor(max_length, max_position_embeddings = null) {
        super();
        Object.assign(this, { max_length, max_position_embeddings });
    }

    /**
     * Flag every sequence whose length has reached `max_length`.
     * @param {number[][]} input_ids
     * @returns {boolean[]}
     */
    _call(input_ids) {
        const limit = this.max_length;
        return input_ids.map(function (ids) {
            return ids.length >= limit;
        });
    }
}
// TODO: add MaxTimeCriteria
/**
 * Stops a sequence once its most recent token is one of the "end-of-sequence" ids.
 */
class EosTokenCriteria extends StoppingCriteria {
    /**
     * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token.
     * Optionally, use a list to set multiple *end-of-sequence* tokens.
     */
    constructor(eos_token_id) {
        super();
        this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
    }

    /**
     * @param {number[][]} input_ids
     * @param {number[][]} scores
     * @returns {boolean[]} `true` for every sequence whose last token is an EOS id.
     */
    _call(input_ids, scores) {
        const stop_ids = this.eos_token_id;
        const done = [];
        for (const ids of input_ids) {
            const last = ids.at(-1);
            // Loose equality lets a bigint token id match a number EOS id.
            done.push(stop_ids.some((eos_id) => eos_id == last));
        }
        return done;
    }
}
/**
 * A stopping criterion driven by an external flag: after `interrupt()`, every sequence is
 * reported as done until `reset()` is called.
 */
class InterruptableStoppingCriteria extends StoppingCriteria {
    constructor() {
        super();
        this.interrupted = false;
    }

    /** Request that generation stop. */
    interrupt() {
        this.interrupted = true;
    }

    /** Clear a previous interruption request. */
    reset() {
        this.interrupted = false;
    }

    /**
     * @param {number[][]} input_ids
     * @param {number[][]} scores
     * @returns {boolean[]} The current flag, repeated once per sequence.
     */
    _call(input_ids, scores) {
        const flag = this.interrupted;
        return Array.from({ length: input_ids.length }, () => flag);
    }
}
/***/ }),
/***/ "./src/generation/streamers.js":
/*!*************************************!*\
!*** ./src/generation/streamers.js ***!
\*************************************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ BaseStreamer: () => (/* binding */ BaseStreamer),
/* harmony export */ TextStreamer: () => (/* binding */ TextStreamer),
/* harmony export */ WhisperTextStreamer: () => (/* binding */ WhisperTextStreamer)
/* harmony export */ });
/* harmony import */ var _utils_core_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ../utils/core.js */ "./src/utils/core.js");
/* harmony import */ var _tokenizers_js__WEBPACK_IMPORTED_MODULE_1__ = __webpack_require__(/*! ../tokenizers.js */ "./src/tokenizers.js");
/* harmony import */ var _env_js__WEBPACK_IMPORTED_MODULE_2__ = __webpack_require__(/*! ../env.js */ "./src/env.js");
/**
* @module generation/streamers
*/
/**
 * Interface for objects that `.generate()` feeds with tokens as they are produced.
 * Both methods throw unless a subclass overrides them.
 */
class BaseStreamer {
    /**
     * Function that is called by `.generate()` to push new tokens
     * @param {bigint[][]} value
     * @throws {Error} Unless overridden.
     */
    put(value) {
        throw new Error('Not implemented');
    }

    /**
     * Function that is called by `.generate()` to signal the end of generation
     * @throws {Error} Unless overridden.
     */
    end() {
        throw new Error('Not implemented');
    }
}
/**
 * Default text sink for `TextStreamer`, chosen once at module load.
 * It uses `process.stdout.write` when a process is available and `console.log` otherwise.
 */
const stdout_write = (() => {
    if (_env_js__WEBPACK_IMPORTED_MODULE_2__.apis.IS_PROCESS_AVAILABLE) {
        return (text) => process.stdout.write(text);
    }
    return (text) => console.log(text);
})();
/**
 * Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.
 *
 * Tokens are cached, and the whole cache is re-decoded on every `put`. New text is emitted:
 *  - completely, when the decoded text ends with a newline (the cache is then flushed);
 *  - completely, when the final character is accepted by `is_chinese_char`;
 *  - otherwise up to and including the last space.
 */
class TextStreamer extends BaseStreamer {
    /**
     * @param {import('../tokenizers.js').PreTrainedTokenizer} tokenizer
     * @param {Object} [options]
     * @param {boolean} [options.skip_prompt=false] When true, the first `put` after construction or
     * after `end()` is ignored (presumably the prompt tokens).
     * @param {function(string): void} [options.callback_function=null] Receives each piece of
     * finalized text. Defaults to `stdout_write`.
     * @param {function(bigint[]): void} [options.token_callback_function=null] Receives the raw token
     * ids of every non-skipped `put`.
     * @param {Object} [options.decode_kwargs={}] Options forwarded to `tokenizer.decode`. Any other
     * option is merged in and takes precedence.
     */
    constructor(tokenizer, {
        skip_prompt = false,
        callback_function = null,
        token_callback_function = null,
        decode_kwargs = {},
        ...kwargs
    } = {}) {
        super();
        this.tokenizer = tokenizer;
        this.skip_prompt = skip_prompt;
        this.callback_function = callback_function ?? stdout_write;
        this.token_callback_function = token_callback_function;
        this.decode_kwargs = { ...decode_kwargs, ...kwargs };

        // variables used in the streaming process
        this.token_cache = [];
        this.print_len = 0;
        this.next_tokens_are_prompt = true;
    }

    /**
     * Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
     * @param {bigint[][]} value
     * @throws {Error} If more than one sequence is passed.
     */
    put(value) {
        if (value.length > 1) {
            throw Error('TextStreamer only supports batch size of 1');
        }

        if (this.skip_prompt && this.next_tokens_are_prompt) {
            this.next_tokens_are_prompt = false;
            return;
        }

        const tokens = value[0];
        this.token_callback_function?.(tokens);

        // Add the new token to the cache and decodes the entire thing.
        this.token_cache = (0,_utils_core_js__WEBPACK_IMPORTED_MODULE_0__.mergeArrays)(this.token_cache, tokens);
        const text = this.tokenizer.decode(this.token_cache, this.decode_kwargs);

        // Code point of the final character. For a character outside the BMP (e.g. CJK Extension B,
        // U+20000+), `charCodeAt(last)` would only give the trailing low surrogate, while
        // `is_chinese_char` expects a code point. `codePointAt(i)` combines a surrogate pair only
        // when `i` points at the high half, hence the lookup at `last_index - 1`.
        const last_index = text.length - 1;
        const last_code_point = last_index > 0 && text.codePointAt(last_index - 1) > 0xFFFF
            ? text.codePointAt(last_index - 1)
            : text.charCodeAt(last_index);

        let printable_text;
        if (text.endsWith('\n')) {
            // After the symbol for a new line, we flush the cache.
            printable_text = text.slice(this.print_len);
            this.token_cache = [];
            this.print_len = 0;
        } else if (text.length > 0 && (0,_tokenizers_js__WEBPACK_IMPORTED_MODULE_1__.is_chinese_char)(last_code_point)) {
            // If the last token is a CJK character, we print the characters.
            printable_text = text.slice(this.print_len);
            this.print_len += printable_text.length;
        } else {
            // Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
            // which may change with the subsequent token -- there are probably smarter ways to do this!)
            printable_text = text.slice(this.print_len, text.lastIndexOf(' ') + 1);
            this.print_len += printable_text.length;
        }

        this.on_finalized_text(printable_text, false);
    }

    /**
     * Flushes any remaining cache and prints a newline to stdout.
     */
    end() {
        let printable_text;
        if (this.token_cache.length > 0) {
            const text = this.tokenizer.decode(this.token_cache, this.decode_kwargs);
            printable_text = text.slice(this.print_len);
            this.token_cache = [];
            this.print_len = 0;
        } else {
            printable_text = '';
        }
        this.next_tokens_are_prompt = true;
        this.on_finalized_text(printable_text, true);
    }

    /**
     * Passes non-empty text to `callback_function`. When the stream ends and the default
     * `stdout_write` sink is used in a process environment, also writes a trailing newline.
     * @param {string} text
     * @param {boolean} stream_end
     */
    on_finalized_text(text, stream_end) {
        if (text.length > 0) {
            this.callback_function?.(text);
        }
        if (stream_end && this.callback_function === stdout_write && _env_js__WEBPACK_IMPORTED_MODULE_2__.apis.IS_PROCESS_AVAILABLE) {
            this.callback_function?.('\n');
        }
    }
}
/**
 * Utility class to handle streaming of tokens generated by whisper speech-to-text models.
 * Callback functions are invoked when each of the following events occur:
 * - A new chunk starts (on_chunk_start)
 * - A new token is generated (callback_function)
 * - A chunk ends (on_chunk_end)
 * - The stream is finalized (on_finalize)
 *
 * A generation step consisting of a single token whose id is >= `tokenizer.timestamp_begin`
 * is treated as a timestamp. It is consumed here, alternately firing `on_chunk_start` and
 * `on_chunk_end`, and is not forwarded to the base `TextStreamer`.
 */
class WhisperTextStreamer extends TextStreamer {
    /**
     * @param {import('../tokenizers.js').WhisperTokenizer} tokenizer
     * @param {Object} options
     * @param {boolean} [options.skip_prompt=false] Whether to skip the prompt tokens
     * @param {function(string): void} [options.callback_function=null] Function to call when a piece of text is ready to display
     * @param {function(string): void} [options.token_callback_function=null] Function to call when a new token is generated (forwarded unchanged to TextStreamer)
     * @param {function(number): void} [options.on_chunk_start=null] Function to call when a new chunk starts
     * @param {function(number): void} [options.on_chunk_end=null] Function to call when a chunk ends
     * @param {function(): void} [options.on_finalize=null] Function to call when the stream is finalized
     * @param {number} [options.time_precision=0.02] Multiplier converting a timestamp offset (token id - timestamp_begin) into the time passed to the chunk callbacks
     * @param {boolean} [options.skip_special_tokens=true] Whether to skip special tokens when decoding
     * @param {Object} [options.decode_kwargs={}] Additional keyword arguments to pass to the tokenizer's decode method; a `skip_special_tokens` key here takes precedence over the option above
     */
    constructor(tokenizer, {
        skip_prompt = false,
        callback_function = null,
        token_callback_function = null,
        on_chunk_start = null,
        on_chunk_end = null,
        on_finalize = null,
        time_precision = 0.02,
        skip_special_tokens = true,
        decode_kwargs = {},
    } = {}) {
        super(tokenizer, {
            skip_prompt,
            callback_function,
            token_callback_function,
            // Spread last so explicit decode_kwargs override the skip_special_tokens option.
            decode_kwargs: { skip_special_tokens, ...decode_kwargs },
        });
        this.timestamp_begin = tokenizer.timestamp_begin;
        this.on_chunk_start = on_chunk_start;
        this.on_chunk_end = on_chunk_end;
        this.on_finalize = on_finalize;
        this.time_precision = time_precision;
        // false: the next timestamp opens a chunk; true: it closes the currently open chunk.
        this.waiting_for_timestamp = false;
    }

    /**
     * Handles one generation step.
     * If the step is a single timestamp token, fires `on_chunk_start` or `on_chunk_end` with
     * `(token_id - timestamp_begin) * time_precision` and forwards an empty step instead.
     * @param {bigint[][]} value Token ids for this step; only a batch size of 1 is supported.
     * @returns {*} Whatever `TextStreamer#put` returns.
     * @throws {Error} If `value` holds more than one sequence.
     */
    put(value) {
        if (value.length > 1) {
            throw new Error('WhisperTextStreamer only supports batch size of 1');
        }
        const tokens = value[0];
        let forwarded = value;
        // Check if the token is a timestamp
        if (tokens.length === 1) {
            const offset = Number(tokens[0]) - this.timestamp_begin;
            if (offset >= 0) {
                const time = offset * this.time_precision;
                if (this.waiting_for_timestamp) {
                    this.on_chunk_end?.(time);
                } else {
                    this.on_chunk_start?.(time);
                }
                this.waiting_for_timestamp = !this.waiting_for_timestamp; // Toggle start/end
                // Forward an empty step so the timestamp token never reaches the base streamer.
                forwarded = [[]];
            }
        }
        return super.put(forwarded);
    }

    /**
     * Flushes the base streamer, then invokes `on_finalize`.
     * The chunk-pairing state is reset afterwards, so a reused streamer interprets the first
     * timestamp of its next stream as a chunk start. Without the reset, a stream that ended
     * with an unclosed chunk would leak its state into the next one.
     */
    end() {
        super.end();
        this.on_finalize?.();
        this.waiting_for_timestamp = false;
    }
}
/***/ }),
/***/ "./src/models.js":
/*!***********************!*\
!*** ./src/models.js ***!
\***********************/
/***/ ((__unused_webpack___webpack_module__, __webpack_exports__, __webpack_require__) => {
__webpack_require__.r(__webpack_exports__);
/* harmony export */ __webpack_require__.d(__webpack_exports__, {
/* harmony export */ ASTForAudioClassification: () => (/* binding */ ASTForAudioClassification),
/* harmony export */ ASTModel: () => (/* binding */ ASTModel),
/* harmony export */ ASTPreTrainedModel: () => (/* binding */ ASTPreTrainedModel),
/* harmony export */ AlbertForMaskedLM: () => (/* binding */ AlbertForMaskedLM),
/* harmony export */ AlbertForQuestionAnswering: () => (/* binding */ AlbertForQuestionAnswering),
/* harmony export */ AlbertForSequenceClassification: () => (/* binding */ AlbertForSequenceClassification),
/* harmony export */ AlbertModel: () => (/* binding */ AlbertModel),
/* harmony export */ AlbertPreTrainedModel: () => (/* binding */ AlbertPreTrainedModel),
/* harmony export */ AutoModel: () => (/* binding */ AutoModel),
/* harmony export */ AutoModelForAudioClassification: () => (/* binding */ AutoModelForAudioClassification),
/* harmony export */ AutoModelForAudioFrameClassification: () => (/* binding */ AutoModelForAudioFrameClassification),
/* harmony export */ AutoModelForCTC: () => (/* binding */ AutoModelForCTC),
/* harmony export */ AutoModelForCausalLM: () => (/* binding */ AutoModelForCausalLM),
/* harmony export */ AutoModelForDepthEstimation: () => (/* binding */ AutoModelForDepthEstimation),
/* harmony export */ AutoModelForDocumentQuestionAnswering: () => (/* binding */ AutoModelForDocumentQuestionAnswering),
/* harmony export */ AutoModelForImageClassification: () => (/* binding */ AutoModelForImageClassification),
/* harmony export */ AutoModelForImageFeatureExtraction: () => (/* binding */ AutoModelForImageFeatureExtraction),
/* harmony export */ AutoModelForImageMatting: () => (/* binding */ AutoModelForImageMatting),
/* harmony export */ AutoModelForImageSegmentation: () => (/* binding */ AutoModelForImageSegmentation),
/* harmony export */ AutoModelForImageToImage: () => (/* binding */ AutoModelForImageToImage),
/* harmony export */ AutoModelForMaskGeneration: () => (/* binding */ AutoModelForMaskGeneration),
/* harmony export */ AutoModelForMaskedLM: () => (/* binding */ AutoModelForMaskedLM),
/* harmony export */ AutoModelForNormalEstimation: () => (/* binding */ AutoModelForNormalEstimation),
/* harmony export */ AutoModelForObjectDetection: () => (/* binding */ AutoModelForObjectDetection),
/* harmony export */ AutoModelForPoseEstimation: () => (/* binding */ AutoModelForPoseEstimation),
/* harmony export */ AutoModelForQuestionAnswering: () => (/* binding */ AutoModelForQuestionAnswering),
/* harmony export */ AutoModelForSemanticSegmentation: () => (/* binding */ AutoModelForSemanticSegmentation),
/* harmony export */ AutoModelForSeq2SeqLM: () => (/* binding */ AutoModelForSeq2SeqLM),
/* harmony export */ AutoModelForSequenceClassification: () => (/* binding */ AutoModelForSequenceClassification),
/* harmony export */ AutoModelForSpeechSeq2Seq: () => (/* binding */ AutoModelForSpeechSeq2Seq),
/* harmony export */ AutoModelForTextToSpectrogram: () => (/* binding */ AutoModelForTextToSpectrogram),
/* harmony export */ AutoModelForTextToWaveform: () => (/* binding */ AutoModelForTextToWaveform),
/* harmony export */ AutoModelForTokenClassification: () => (/* binding */ AutoModelForTokenClassification),
/* harmony export */ AutoModelForUniversalSegmentation: () => (/* binding */ AutoModelForUniversalSegmentation),
/* harmony export */ AutoModelForVision2Seq: () => (/* binding */ AutoModelForVision2Seq),
/* harmony export */ AutoModelForXVector: () => (/* binding */ AutoModelForXVector),
/* harmony export */ AutoModelForZeroShotObjectDetection: () => (/* binding */ AutoModelForZeroShotObjectDetection),
/* harmony export */ BartForConditionalGeneration: () => (/* binding */ BartForConditionalGeneration),
/* harmony export */ BartForSequenceClassification: () => (/* binding */ BartForSequenceClassification),
/* harmony export */ BartModel: () => (/* binding */ BartModel),
/* harmony export */ BartPretrainedModel: () => (/* binding */ BartPretrainedModel),
/* harmony export */ BaseModelOutput: () => (/* binding */ BaseModelOutput),
/* harmony export */ BeitForImageClassification: () => (/* binding */ BeitForImageClassification),
/* harmony export */ BeitModel: () => (/* binding */ BeitModel),
/* harmony export */ BeitPreTrainedModel: () => (/* binding */ BeitPreTrainedModel),
/* harmony export */ BertForMaskedLM: () => (/* binding */ BertForMaskedLM),
/* harmony export */ BertForQuestionAnswering: () => (/* binding */ BertForQuestionAnswering),
/* harmony export */ BertForSequenceClassification: () => (/* binding */ BertForSequenceClassification),
/* harmony export */ BertForTokenClassification: () => (/* binding */ BertForTokenClassification),
/* harmony export */ BertModel: () => (/* binding */ BertModel),
/* harmony export */ BertPreTrainedModel: () => (/* binding */ BertPreTrainedModel),
/* harmony export */ BlenderbotForConditionalGeneration: () => (/* binding */ BlenderbotForConditionalGeneration),
/* harmony export */ BlenderbotModel: () => (/* binding */ BlenderbotModel),
/* harmony export */ BlenderbotPreTrainedModel: () => (/* binding */ BlenderbotPreTrainedModel),
/* harmony export */ BlenderbotSmallForConditionalGeneration: () => (/* binding */ BlenderbotSmallForConditionalGeneration),
/* harmony export */ BlenderbotSmallModel: () => (/* binding */ BlenderbotSmallModel),
/* harmony export */ BlenderbotSmallPreTrainedModel: () => (/* binding */ BlenderbotSmallPreTrainedModel),
/* harmony export */ BloomForCausalLM: () => (/* binding */ BloomForCausalLM),
/* harmony export */ BloomModel: () => (/* binding */ BloomModel),
/* harmony export */ BloomPreTrainedModel: () => (/* binding */ BloomPreTrainedModel),
/* harmony export */ CLIPModel: () => (/* binding */ CLIPModel),
/* harmony export */ CLIPPreTrainedModel: () => (/* binding */ CLIPPreTrainedModel),
/* harmony export */ CLIPSegForImageSegmentation: () => (/* binding */ CLIPSegForImageSegmentation),
/* harmony export */ CLIPSegModel: () => (/* binding */ CLIPSegModel),
/* harmony export */ CLIPSegPreTrainedModel: () => (/* binding */ CLIPSegPreTrainedModel),
/* harmony export */ CLIPTextModel: () => (/* binding */ CLIPTextModel),
/* harmony export */ CLIPTextModelWithProjection: () => (/* binding */ CLIPTextModelWithProjection),
/* harmony export */ CLIPVisionModel: () => (/* binding */ CLIPVisionModel),
/* harmony export */ CLIPVisionModelWithProjection: () => (/* binding */ CLIPVisionModelWithProjection),
/* harmony export */ CamembertForMaskedLM: () => (/* binding */ CamembertForMaskedLM),
/* harmony export */ CamembertForQuestionAnswering: () => (/* binding */ CamembertForQuestionAnswering),
/* harmony export */ CamembertForSequenceClassification: () => (/* binding */ CamembertForSequenceClassification),
/* harmony export */ CamembertForTokenClassification: () => (/* binding */ CamembertForTokenClassification),
/* harmony export */ CamembertModel: () => (/* binding */ CamembertModel),
/* harmony export */ CamembertPreTrainedModel: () => (/* binding */ CamembertPreTrainedModel),
/* harmony export */ CausalLMOutput: () => (/* binding */ CausalLMOutput),
/* harmony export */ CausalLMOutputWithPast: () => (/* binding */ CausalLMOutputWithPast),
/* harmony export */ ChineseCLIPModel: () => (/* binding */ ChineseCLIPModel),
/* harmony export */ ChineseCLIPPreTrainedModel: () => (/* binding */ ChineseCLIPPreTrainedModel),
/* harmony export */ ClapAudioModelWithProjection: () => (/* binding */ ClapAudioModelWithProjection),
/* harmony export */ ClapModel: () => (/* binding */ ClapModel),
/* harmony export */ ClapPreTrainedModel: () => (/* binding */ ClapPreTrainedModel),
/* harmony export */ ClapTextModelWithProjection: () => (/* binding */ ClapTextModelWithProjection),
/* harmony export */ CodeGenForCausalLM: () => (/* binding */ CodeGenForCausalLM),
/* harmony export */ CodeGenModel: () => (/* binding */ CodeGenModel),
/* harmony export */ CodeGenPreTrainedModel: () => (/* binding */ CodeGenPreTrainedModel),
/* harmony export */ CohereForCausalLM: () => (/* binding */ CohereForCausalLM),
/* harmony export */ CohereModel: () => (/* binding */ CohereModel),
/* harmony export */ CoherePreTrainedModel: () => (/* binding */ CoherePreTrainedModel),
/* harmony export */ ConvBertForMaskedLM: () => (/* binding */ ConvBertForMaskedLM),
/* harmony export */ ConvBertForQuestionAnswering: () => (/* binding */ ConvBertForQuestionAnswering),
/* harmony export */ ConvBertForSequenceClassification: () => (/* binding */ ConvBertForSequenceClassification),
/* harmony export */ ConvBertForTokenClassification: () => (/* binding */ ConvBertForTokenClassification),
/* harmony export */ ConvBertModel: () => (/* binding */ ConvBertModel),
/* harmony export */ ConvBertPreTrainedModel: () => (/* binding */ ConvBertPreTrainedModel),
/* harmony export */ ConvNextForImageClassification: () => (/* binding */ ConvNextForImageClassification),
/* harmony export */ ConvNextModel: () => (/* binding */ ConvNextModel),
/* harmony export */ ConvNextPreTrainedModel: () => (/* binding */ ConvNextPreTrainedModel),
/* harmony export */ ConvNextV2ForImageClassification: () => (/* binding */ ConvNextV2ForImageClassification),
/* harmony export */ ConvNextV2Model: () => (/* binding */ ConvNextV2Model),
/* harmony export */ ConvNextV2PreTrainedModel: () => (/* binding */ ConvNextV2PreTrainedModel),
/* harmony export */ DPTForDepthEstimation: () => (/* binding */ DPTForDepthEstimation),
/* harmony export */ DPTModel: () => (/* binding */ DPTModel),
/* harmony export */ DPTPreTrainedModel: () => (/* binding */ DPTPreTrainedModel),
/* harmony export */ DebertaForMaskedLM: () => (/* binding */ DebertaForMaskedLM),
/* harmony export */ DebertaForQuestionAnswering: () => (/* binding */ DebertaForQuestionAnswering),
/* harmony export */ DebertaForSequenceClassification: () => (/* binding */ DebertaForSequenceClassification),
/* harmony export */ DebertaForTokenClassification: () => (/* binding */ DebertaForTokenClassification),
/* harmony export */ DebertaModel: () => (/* binding */ DebertaModel),
/* harmony export */ DebertaPreTrainedModel: () => (/* binding */ DebertaPreTrainedModel),
/* harmony export */ DebertaV2ForMaskedLM: () => (/* binding */ DebertaV2ForMaskedLM),
/* harmony export */ DebertaV2ForQuestionAnswering: () => (/* binding */ DebertaV2ForQuestionAnswering),
/* harmony export */ DebertaV2ForSequenceClassification: () => (/* binding */ DebertaV2ForSequenceClassification),
/* harmony export */ DebertaV2ForTokenClassification: () => (/* binding */ DebertaV2ForTokenClassification),
/* harmony export */ DebertaV2Model: () => (/* binding */ DebertaV2Model),
/* harmony export */ DebertaV2PreTrainedModel: () => (/* binding */ DebertaV2PreTrainedModel),
/* harmony export */ DecisionTransformerModel: () => (/* binding */ DecisionTransformerModel),
/* harmony export */ DecisionTransformerPreTrainedModel: () => (/* binding */ DecisionTransformerPreTrainedModel),
/* harmony export */ DeiTForImageClassification: () => (/* binding */ DeiTForImageClassification),
/* harmony export */ DeiTModel: () => (/* binding */ DeiTModel),
/* harmony export */ DeiTPreTrainedModel: () => (/* binding */ DeiTPreTrainedModel),
/* harmony export */ DepthAnythingForDepthEstimation: () => (/* binding */ DepthAnythingForDepthEstimation),
/* harmony export */ DepthAnythingPreTrainedModel: () => (/* binding */ DepthAnythingPreTrainedModel),
/* harmony export */ DepthProForDepthEstimation: () => (/* binding */ DepthProForDepthEstimation),
/* harmony export */ DepthProPreTrainedModel: () => (/* binding */ DepthProPreTrainedModel),
/* harmony export */ DetrForObjectDetection: () => (/* binding */ DetrForObjectDetection),
/* harmony export */ DetrForSegmentation: () => (/* binding */ DetrForSegmentation),
/* harmony export */ DetrModel: () => (/* binding */ DetrModel),
/* harmony export */ DetrObjectDetectionOutput: () => (/* binding */ DetrObjectDetectionOutput),
/* harmony export */ DetrPreTrainedModel: () => (/* binding */ DetrPreTrainedModel),
/* harmony export */ DetrSegmentationOutput: () => (/* binding */ DetrSegmentationOutput),
/* harmony export */ Dinov2ForImageClassification: () => (/* binding */ Dinov2ForImageClassification),
/* harmony export */ Dinov2Model: () => (/* binding */ Dinov2Model),
/* harmony export */ Dinov2PreTrainedModel: () => (/* binding */ Dinov2PreTrainedModel),
/* harmony export */ DistilBertForMaskedLM: () => (/* binding */ DistilBertForMaskedLM),
/* harmony export */ DistilBertForQuestionAnswering: () => (/* binding */ DistilBertForQuestionAnswering),
/* harmony export */ DistilBertForSequenceClassification: () => (/* binding */ DistilBertForSequenceClassification),
/* harmony export */ DistilBertForTokenClassification: () => (/* binding */ DistilBertForTokenClassification),
/* harmony export */ DistilBertModel: () => (/* binding */ DistilBertModel),
/* harmony export */ DistilBertPreTrainedModel: () => (/* binding */ DistilBertPreTrainedModel),
/* harmony export */ DonutSwinModel: () => (/* binding */ DonutSwinModel),
/* harmony export */ DonutSwinPreTrainedModel: () => (/* binding */ DonutSwinPreTrainedModel),
/* harmony export */ EfficientNetForImageClassification: () => (/* binding */ EfficientNetForImageClassification),
/* harmony export */ EfficientNetModel: () => (/* binding */ EfficientNetModel),
/* harmony export */ EfficientNetPreTrainedModel: () => (/* binding */ EfficientNetPreTrainedModel),
/* harmony export */ ElectraForMaskedLM: () => (/* binding */ ElectraForMaskedLM),
/* harmony export */ ElectraForQuestionAnswering: () => (/* binding */ ElectraForQuestionAnswering),
/* harmony export */ ElectraForSequenceClassification: () => (/* binding */ ElectraForSequenceClassification),
/* harmony export */ ElectraForTokenClassification: () => (/* binding */ ElectraForTokenClassification),
/* harmony export */ ElectraModel: () => (/* binding */ ElectraModel),
/* harmony export */ ElectraPreTrainedModel: () => (/* binding */ ElectraPreTrainedModel),
/* harmony export */ EsmForMaskedLM: () => (/* binding */ EsmForMaskedLM),
/* harmony export */ EsmForSequenceClassification: () => (/* binding */ EsmForSequenceClassification),
/* harmony export */ EsmForTokenClassification: () => (/* binding */ EsmForTokenClassification),
/* harmony export */ EsmModel: () => (/* binding */ EsmModel),
/* harmony export */ EsmPreTrainedModel: () => (/* binding */ EsmPreTrainedModel),
/* harmony export */ FalconForCausalLM: () => (/* binding */ FalconForCausalLM),
/* harmony export */ FalconModel: () => (/* binding */ FalconModel),
/* harmony export */ FalconPreTrainedModel: () => (/* binding */ FalconPreTrainedModel),
/* harmony export */ FastViTForImageClassification: () => (/* binding */ FastViTForImageClassification),
/* harmony export */ FastViTModel: () => (/* binding */ FastViTModel),
/* harmony export */ FastViTPreTrainedModel: () => (/* binding */ FastViTPreTrainedModel),
/* harmony export */ Florence2ForConditionalGeneration: () => (/* binding */ Florence2ForConditionalGeneration),
/* harmony export */ Florence2PreTrainedModel: () => (/* binding */ Florence2PreTrainedModel),
/* harmony export */ GLPNForDepthEstimation: () => (/* binding */ GLPNForDepthEstimation),
/* harmony export */ GLPNModel: () => (/* binding */ GLPNModel),
/* harmony export */ GLPNPreTrainedModel: () => (/* binding */ GLPNPreTrainedModel),
/* harmony export */ GPT2LMHeadModel: () => (/* binding */ GPT2LMHeadModel),
/* harmony export */ GPT2Model: () => (/* binding */ GPT2Model),
/* harmony export */ GPT2PreTrainedModel: () => (/* binding */ GPT2PreTrainedModel),
/* harmony export */ GPTBigCodeForCausalLM: () => (/* binding */ GPTBigCodeForCausalLM),
/* harmony export */ GPTBigCodeModel: () => (/* binding */ GPTBigCodeModel),
/* harmony export */ GPTBigCodePreTrainedModel: () => (/* binding */ GPTBigCodePreTrainedModel),
/* harmony export */ GPTJForCausalLM: () => (/* binding */ GPTJForCausalLM),
/* harmony export */ GPTJModel: () => (/* binding */ GPTJModel),
/* harmony export */ GPTJPreTrainedModel: () => (/* binding */ GPTJPreTrainedModel),
/* harmony export */ GPTNeoForCausalLM: () => (/* binding */ GPTNeoForCausalLM),
/* harmony export */ GPTNeoModel: () => (/* binding */ GPTNeoModel),
/* harmony export */ GPTNeoPreTrainedModel: () => (/* binding */ GPTNeoPreTrainedModel),
/* harmony export */ GPTNeoXForCausalLM: () => (/* binding */ GPTNeoXForCausalLM),
/* harmony export */ GPTNeoXModel: () => (/* binding */ GPTNeoXModel),
/* harmony export */ GPTNeoXPreTrainedModel: () => (/* binding */ GPTNeoXPreTrainedModel),
/* harmony export */ Gemma2ForCausalLM: () => (/* binding */ Gemma2ForCausalLM),
/* harmony export */ Gemma2Model: () => (/* binding */ Gemma2Model),
/* harmony export */ Gemma2PreTrainedModel: () => (/* binding */ Gemma2PreTrainedModel),
/* harmony export */ GemmaForCausalLM: () => (/* binding */ GemmaForCausalLM),
/* harmony export */ GemmaModel: () => (/* binding */ GemmaModel),
/* harmony export */ GemmaPreTrainedModel: () => (/* binding */ GemmaPreTrainedModel),
/* harmony export */ GraniteForCausalLM: () => (/* binding */ GraniteForCausalLM),
/* harmony export */ GraniteModel: () => (/* binding */ GraniteModel),
/* harmony export */ GranitePreTrainedModel: () => (/* binding */ GranitePreTrainedModel),
/* harmony export */ GroupViTModel: () => (/* binding */ GroupViTModel),
/* harmony export */ GroupViTPreTrainedModel: () => (/* binding */ GroupViTPreTrainedModel),
/* harmony export */ HieraForImageClassification: () => (/* binding */ HieraForImageClassification),
/* harmony export */ HieraModel: () => (/* binding */ HieraModel),
/* harmony export */ HieraPreTrainedModel: () => (/* binding */ HieraPreTrainedModel),
/* harmony export */ HubertForCTC: () => (/* binding */ HubertForCTC),
/* harmony export */ HubertForSequenceClassification: () => (/* binding */ HubertForSequenceClassification),
/* harmony export */ HubertModel: () => (/* binding */ HubertModel),
/* harmony export */ HubertPreTrainedModel: () => (/* binding */ HubertPreTrainedModel),
/* harmony export */ ImageMattingOutput: () => (/* binding */ ImageMattingOutput),
/* harmony export */ JAISLMHeadModel: () => (/* binding */ JAISLMHeadModel),
/* harmony export */ JAISModel: () => (/* binding */ JAISModel),
/* harmony export */ JAISPreTrainedModel: () => (/* binding */ JAISPreTrainedModel),
/* harmony export */ JinaCLIPModel: () => (/* binding */ JinaCLIPModel),
/* harmony export */ JinaCLIPPreTrainedModel: () => (/* binding */ JinaCLIPPreTrainedModel),
/* harmony export */ JinaCLIPTextModel: () => (/* binding */ JinaCLIPTextModel),
/* harmony export */ JinaCLIPVisionModel: () => (/* binding */ JinaCLIPVisionModel),
/* harmony export */ LlamaForCausalLM: () => (/* binding */ LlamaForCausalLM),
/* harmony export */ LlamaModel: () => (/* binding */ LlamaModel),
/* harmony export */ LlamaPreTrainedModel: () => (/* binding */ LlamaPreTrainedModel),
/* harmony export */ LlavaForConditionalGeneration: () => (/* binding */ LlavaForConditionalGeneration),
/* harmony export */ LlavaOnevisionForConditionalGeneration: () => (/* binding */ LlavaOnevisionForConditionalGeneration),
/* harmony export */ LlavaPreTrainedModel: () => (/* binding */ LlavaPreTrainedModel),
/* harmony export */ LongT5ForConditionalGeneration: () => (/* binding */ LongT5ForConditionalGeneration),
/* harmony export */ LongT5Model: () => (/* binding */ LongT5Model),
/* harmony export */ LongT5PreTrainedModel: () => (/* binding */ LongT5PreTrainedModel),
/* harmony export */ M2M100ForConditionalGeneration: () => (/* binding */ M2M100ForConditionalGeneration),
/* harmony export */ M2M100Model: () => (/* binding */ M2M100Model),
/* harmony export */ M2M100PreTrainedModel: () => (/* binding */ M2M100PreTrainedModel),
/* harmony export */ MBartForCausalLM: () => (/* binding */ MBartForCausalLM),
/* harmony export */ MBartForConditionalGeneration: () => (/* binding */ MBartForConditionalGeneration),
/* harmony export */ MBartForSequenceClassification: () => (/* binding */ MBartForSequenceClassification),
/* harmony export */ MBartModel: () => (/* binding */ MBartModel),
/* harmony export */ MBartPreTrainedModel: () => (/* binding */ MBartPreTrainedModel),
/* harmony export */ MPNetForMaskedLM: () => (/* binding */ MPNetForMaskedLM),
/* harmony export */ MPNetForQuestionAnswering: () => (/* binding */ MPNetForQuestionAnswering),
/* harmony export */ MPNetForSequenceClassification: () => (/* binding */ MPNetForSequenceClassification),
/* harmony export */ MPNetForTokenClassification: () => (/* binding */ MPNetForTokenClassification),
/* harmony export */ MPNetModel: () => (/* binding */ MPNetModel),
/* harmony export */ MPNetPreTrainedModel: () => (/* binding */ MPNetPreTrainedModel),
/* harmony export */ MT5ForConditionalGeneration: () => (/* binding */ MT5ForConditionalGeneration),
/* harmony export */ MT5Model: () => (/* binding */ MT5Model),
/* harmony export */ MT5PreTrainedModel: () => (/* binding */ MT5PreTrainedModel),
/* harmony export */ MarianMTModel: () => (/* binding */ MarianMTModel),
/* harmony export */ MarianModel: () => (/* binding */ MarianModel),
/* harmony export */ MarianPreTrainedModel: () => (/* binding */ MarianPreTrainedModel),
/* harmony export */ MaskFormerForInstanceSegmentation: () => (/* binding */ MaskFormerForInstanceSegmentation),
/* harmony export */ MaskFormerModel: () => (/* binding */ MaskFormerModel),
/* harmony export */ MaskFormerPreTrainedModel: () => (/* binding */ MaskFormerPreTrainedModel),
/* harmony export */ MaskedLMOutput: () => (/* binding */ MaskedLMOutput),
/* harmony export */ MgpstrForSceneTextRecognition: () => (/* binding */ MgpstrForSceneTextRecognition),
/* harmony export */ MgpstrModelOutput: () => (/* binding */ MgpstrModelOutput),
/* harmony export */ MgpstrPreTrainedModel: () => (/* binding */ MgpstrPreTrainedModel),
/* harmony export */ MistralForCausalLM: () => (/* binding */ MistralForCausalLM),
/* harmony export */ MistralModel: () => (/* binding */ MistralModel),
/* harmony export */ MistralPreTrainedModel: () => (/* binding */ MistralPreTrainedModel),
/* harmony export */ MobileBertForMaskedLM: () => (/* binding */ MobileBertForMaskedLM),
/* harmony export */ MobileBertForQuestionAnswering: () => (/* binding */ MobileBertForQuestionAnswering),
/* harmony export */ MobileBertForSequenceClassification: () => (/* binding */ MobileBertForSequenceClassification),
/* harmony export */ MobileBertModel: () => (/* binding */ MobileBertModel),
/* harmony export */ MobileBertPreTrainedModel: () => (/* binding */ MobileBertPreTrainedModel),
/* harmony export */ MobileLLMForCausalLM: () => (/* binding */ MobileLLMForCausalLM),
/* harmony export */ MobileLLMModel: () => (/* binding */ MobileLLMModel),
/* harmony export */ MobileLLMPreTrainedModel: () => (/* binding */ MobileLLMPreTrainedModel),
/* harmony export */ MobileNetV1ForImageClassification: () => (/* binding */ MobileNetV1ForImageClassification),
/* harmony export */ MobileNetV1Model: () => (/* binding */ MobileNetV1Model),
/* harmony export */ MobileNetV1PreTrainedModel: () => (/* binding */ MobileNetV1PreTrainedModel),
/* harmony export */ MobileNetV2ForImageClassification: () => (/* binding */ MobileNetV2ForImageClassification),
/* harmony export */ MobileNetV2Model: () => (/* binding */ MobileNetV2Model),
/* harmony export */ MobileNetV2PreTrainedModel: () => (/* binding */ MobileNetV2PreTrainedModel),
/* harmony export */ MobileNetV3ForImageClassification: () => (/* binding */ MobileNetV3ForImageClassification),
/* harmony export */ MobileNetV3Model: () => (/* binding */ MobileNetV3Model),
/* harmony export */ MobileNetV3PreTrainedModel: () => (/* binding */ MobileNetV3PreTrainedModel),
/* harmony export */ MobileNetV4ForImageClassification: () => (/* binding */ MobileNetV4ForImageClassification),
/* harmony export */ MobileNetV4Model: () => (/* binding */ MobileNetV4Model),
/* harmony export */ MobileNetV4PreTrainedModel: () => (/* binding */ MobileNetV4PreTrainedModel),
/* harmony export */ MobileViTForImageClassification: () => (/* binding */ MobileViTForImageClassification),
/* harmony export */ MobileViTModel: () => (/* binding */ MobileViTModel),
/* harmony export */ MobileViTPreTrainedModel: () => (/* binding */ MobileViTPreTrainedModel),
/* harmony export */ MobileViTV2ForImageClassification: () => (/* binding */ MobileViTV2ForImageClassification),
/* harmony export */ MobileViTV2Model: () => (/* binding */ MobileViTV2Model),
/* harmony export */ MobileViTV2PreTrainedModel: () => (/* binding */ MobileViTV2PreTrainedModel),
/* harmony export */ ModelOutput: () => (/* binding */ ModelOutput),
/* harmony export */ Moondream1ForConditionalGeneration: () => (/* binding */ Moondream1ForConditionalGeneration),
/* harmony export */ MptForCausalLM: () => (/* binding */ MptForCausalLM),
/* harmony export */ MptModel: () => (/* binding */ MptModel),
/* harmony export */ MptPreTrainedModel: () => (/* binding */ MptPreTrainedModel),
/* harmony export */ MultiModalityCausalLM: () => (/* binding */ MultiModalityCausalLM),
/* harmony export */ MultiModalityPreTrainedModel: () => (/* binding */ MultiModalityPreTrainedModel),
/* harmony export */ MusicgenForCausalLM: () => (/* binding */ MusicgenForCausalLM),
/* harmony export */ MusicgenForConditionalGeneration: () => (/* binding */ MusicgenForConditionalGeneration),
/* harmony export */ MusicgenModel: () => (/* binding */ MusicgenModel),
/* harmony export */ MusicgenPreTrainedModel: () => (/* binding */ MusicgenPreTrainedModel),
/* harmony export */ NomicBertModel: () => (/* binding */ NomicBertModel),
/* harmony export */ NomicBertPreTrainedModel: () => (/* binding */ NomicBertPreTrainedModel),
/* harmony export */ OPTForCausalLM: () => (/* binding */ OPTForCausalLM),
/* harmony export */ OPTModel: () => (/* binding */ OPTModel),
/* harmony export */ OPTPreTrainedModel: () => (/* binding */ OPTPreTrainedModel),
/* harmony export */ OlmoForCausalLM: () => (/* binding */ OlmoForCausalLM),
/* harmony export */ OlmoModel: () => (/* binding */ OlmoModel),
/* harmony export */ OlmoPreTrainedModel: () => (/* binding */ OlmoPreTrainedModel),
/* harmony export */ OpenELMForCausalLM: () => (/* binding */ OpenELMForCausalLM),
/* harmony export */ OpenELMModel: () => (/* binding */ OpenELMModel),
/* harmony export */ OpenELMPreTrainedModel: () => (/* binding */ OpenELMPreTrainedModel),
/* harmony export */ OwlViTForObjectDetection: () => (/* binding */ OwlViTForObjectDetection),
/* harmony export */ OwlViTModel: () => (/* binding */ OwlViTModel),
/* harmony export */ OwlViTPreTrainedModel: () => (/* binding */ OwlViTPreTrainedModel),
/* harmony export */ Owlv2ForObjectDetection: () => (/* binding */ Owlv2ForObjectDetection),
/* harmony export */ Owlv2Model: () => (/* binding */ Owlv2Model),
/* harmony export */ Owlv2PreTrainedModel: () => (/* binding */ Owlv2PreTrainedModel),
/* harmony export */ PatchTSMixerForPrediction: () => (/* binding */ PatchTSMixerForPrediction),
/* harmony export */ PatchTSMixerModel: () => (/* binding */ PatchTSMixerModel),
/* harmony export */ PatchTSMixerPreTrainedModel: () => (/* binding */ PatchTSMixerPreTrainedModel),
/* harmony export */ PatchTSTForPrediction: () => (/* binding */ PatchTSTForPrediction),
/* harmony export */ PatchTSTModel: () => (/* binding */ PatchTSTModel),
/* harmony export */ PatchTSTPreTrainedModel: () => (/* binding */ PatchTSTPreTrainedModel),
/* harmony export */ Phi3ForCausalLM: () => (/* binding */ Phi3ForCausalLM),
/* harmony export */ Phi3Model: () => (/* binding */ Phi3Model),
/* harmony export */ Phi3PreTrainedModel: () => (/* binding */ Phi3PreTrainedModel),
/* harmony export */ PhiForCausalLM: () => (/* binding */ PhiForCausalLM),
/* harmony export */ PhiModel: () => (/* binding */ PhiModel),
/* harmony export */ PhiPreTrainedModel: () => (/* binding */ PhiPreTrainedModel),
/* harmony export */ PreTrainedModel: () => (/* binding */ PreTrainedModel),
/* harmony export */ PretrainedMixin: () => (/* binding */ PretrainedMixin),
/* harmony export */ PvtForImageClassification: () => (/* binding */ PvtForImageClassification),
/* harmony export */ PvtModel: () => (/* binding */ PvtModel),
/* harmony export */ PvtPreTrainedModel: () => (/* binding */ PvtPreTrainedModel),
/* harmony export */ PyAnnoteForAudioFrameClassification: () => (/* binding */ PyAnnoteForAudioFrameClassification),
/* harmony export */ PyAnnoteModel: () => (/* binding */ PyAnnoteModel),
/* harmony export */ PyAnnotePreTrainedModel: () => (/* binding */ PyAnnotePreTrainedModel),
/* harmony export */ QuestionAnsweringModelOutput: () => (/* binding */ QuestionAnsweringModelOutput),
/* harmony export */ Qwen2ForCausalLM: () => (/* binding */ Qwen2ForCausalLM),
/* harmony export */ Qwen2Model: () => (/* binding */ Qwen2Model),
/* harmony export */ Qwen2PreTrainedModel: () => (/* binding */ Qwen2PreTrainedModel),
/* harmony export */ Qwen2VLForConditionalGeneration: () => (/* binding */ Qwen2VLForConditionalGeneration),
/* harmony export */ Qwen2VLPreTrainedModel: () => (/* binding */ Qwen2VLPreTrainedModel),
/* harmony export */ RTDetrForObjectDetection: () => (/* binding */ RTDetrForObjectDetection),
/* harmony export */ RTDetrModel: () => (/* binding */ RTDetrModel),
/* harmony export */ RTDetrObjectDetectionOutput: () => (/* binding */ RTDetrObjectDetectionOutput),
/* harmony export */ RTDetrPreTrainedModel: () => (/* binding */ RTDetrPreTrainedModel),
/* harmony export */ ResNetForImageClassification: () => (/* binding */ ResNetForImageClassification),
/* harmony export */ ResNetModel: () => (/* binding */ ResNetModel),
/* harmony export */ ResNetPreTrainedModel: () => (/* binding */ ResNetPreTrainedModel),
/* harmony export */ RoFormerForMaskedLM: () => (/* binding */ RoFormerForMaskedLM),
/* harmony export */ RoFormerForQuestionAnswering: () => (/* binding */ RoFormerForQuestionAnswering),
/* harmony export */ RoFormerForSequenceClassification: () => (/* binding */ RoFormerForSequenceClassification),
/* harmony export */ RoFormerForTokenClassification: () => (/* binding */ RoFormerForTokenClassification),
/* harmony export */ RoFormerModel: () => (/* binding */ RoFormerModel),
/* harmony export */ RoFormerPreTrainedModel: () => (/* binding */ RoFormerPreTrainedModel),
/* harmony export */ RobertaForMaskedLM: () => (/* binding */ RobertaForMaskedLM),
/* harmony export */ RobertaForQuestionAnswering: () => (/* binding */ RobertaForQuestionAnswering),
/* harmony export */ RobertaForSequenceClassification: () => (/* binding */ RobertaForSequenceClassification),
/* harmony export */ RobertaForTokenClassification: () => (/* binding */ RobertaForTokenClassification),
/* harmony export */ RobertaModel: () => (/* binding */ RobertaModel),
/* harmony export */ RobertaPreTrainedModel: () => (/* binding */ RobertaPreTrainedModel),
/* harmony export */ SamImageSegmentationOutput: () => (/* binding */ SamImageSegmentationOutput),
/* harmony export */ SamModel: () => (/* binding */ SamModel),
/* harmony export */ SamPreTrainedModel: () => (/* binding */ SamPreTrainedModel),
/* harmony export */ SapiensForDepthEstimation: () => (/* binding */ SapiensForDepthEstimation),
/* harmony export */ SapiensForNormalEstimation: () => (/* binding */ SapiensForNormalEstimation),
/* harmony export */ SapiensForSemanticSegmentation: () => (/* binding */ SapiensForSemanticSegmentation),
/* harmony export */ SapiensPreTrainedModel: () => (/* binding */ SapiensPreTrainedModel),
/* harmony export */ SegformerForImageClassification: () => (/* binding */ SegformerForImageClassification),
/* harmony export */ SegformerForSemanticSegmentation: () => (/* binding */ SegformerForSemanticSegmentation),
/* harmony export */ SegformerModel: () => (/* binding */ SegformerModel),
/* harmony export */ SegformerPreTrainedModel: () => (/* binding */ SegformerPreTrainedModel),
/* harmony export */ Seq2SeqLMOutput: () => (/* binding */ Seq2SeqLMOutput),
/* harmony export */ SequenceClassifierOutput: () => (/* binding */ SequenceClassifierOutput),
/* harmony export */ SiglipModel: () => (/* binding */ SiglipModel),
/* harmony export */ SiglipPreTrainedModel: () => (/* binding */ SiglipPreTrainedModel),
/* harmony export */ SiglipTextModel: () => (/* binding */ SiglipTextModel),
/* harmony export */ SiglipVisionModel: () => (/* binding */ SiglipVisionModel),
/* harmony export */ SpeechT5ForSpeechToText: () => (/* binding */ SpeechT5ForSpeechToText),
/* harmony export */ SpeechT5ForTextToSpeech: () => (/* binding */ SpeechT5ForTextToSpeech),
/* harmony export */ SpeechT5HifiGan: () => (/* binding */ SpeechT5HifiGan),
/* harmony export */ SpeechT5Model: () => (/* binding */ SpeechT5Model),
/* harmony export */ SpeechT5PreTrainedModel: () => (/* binding */ SpeechT5PreTrainedModel),
/* harmony export */ SqueezeBertForMaskedLM: () => (/* binding */ SqueezeBertForMaskedLM),
/* harmony export */ SqueezeBertForQuestionAnswering: () => (/* binding */ SqueezeBertForQuestionAnswering),
/* harmony export */ SqueezeBertForSequenceClassification: () => (/* binding */ SqueezeBertForSequenceClassification),
/* harmony export */ SqueezeBertModel: () => (/* binding */ SqueezeBertModel),
/* harmony export */ SqueezeBertPreTrainedModel: () => (/* binding */ SqueezeBertPreTrainedModel),
/* harmony export */ StableLmForCausalLM: () => (/* binding */ StableLmForCausalLM),
/* harmony export */ StableLmModel: () => (/* binding */ StableLmModel),
/* harmony export */ StableLmPreTrainedModel: () => (/* binding */ StableLmPreTrainedModel),
/* harmony export */ Starcoder2ForCausalLM: () => (/* binding */ Starcoder2ForCausalLM),
/* harmony export */ Starcoder2Model: () => (/* binding */ Starcoder2Model),
/* harmony export */ Starcoder2PreTrainedModel: () => (/* binding */ Starcoder2PreTrainedModel),
/* harmony export */ Swin2SRForImageSuperResolution: () => (/* binding */ Swin2SRForImageSuperResolution),
/* harmony export */ Swin2SRModel: () => (/* binding */ Swin2SRModel),
/* harmony export */ Swin2SRPreTrainedModel: () => (/* binding */ Swin2SRPreTrainedModel),
/* harmony export */ SwinForImageClassification: () => (/* binding */ SwinForImageClassification),
/* harmony export */ SwinModel: () => (/* binding */ SwinModel),
/* harmony export */ SwinPreTrainedModel: () => (/* binding */ SwinPreTrainedModel),
/* harmony export */ T5ForConditionalGeneration: () => (/* binding */ T5ForConditionalGeneration),
/* harmony export */ T5Model: () => (/* binding */ T5Model),
/* harmony export */ T5PreTrainedModel: () => (/* binding */ T5PreTrainedModel),
/* harmony export */ TableTransformerForObjectDetection: () => (/* binding */ TableTransformerForObjectDetection),
/* harmony export */ TableTransformerModel: () => (/* binding */ TableTransformerModel),
/* harmony export */ TableTransformerObjectDetectionOutput: () => (/* binding */ TableTransformerObjectDetectionOutput),
/* harmony export */ TableTransformerPreTrainedModel: () => (/* binding */ TableTransformerPreTrainedModel),
/* harmony export */ TokenClassifierOutput: () => (/* binding */ TokenClassifierOutput),
/* harmony export */ TrOCRForCausalLM: () => (/* binding */ TrOCRForCausalLM),
/* harmony export */ TrOCRPreTrainedModel: () => (/* binding */ TrOCRPreTrainedModel),
/* harmony export */ UniSpeechForCTC: () => (/* binding */ UniSpeechForCTC),
/* harmony export */ UniSpeechForSequenceClassification: () => (/* binding */ UniSpeechForSequenceClassification),
/* harmony export */ UniSpeechModel: () => (/* binding */ UniSpeechModel),
/* harmony export */ UniSpeechPreTrainedModel: () => (/* binding */ UniSpeechPreTrainedModel),
/* harmony export */ UniSpeechSatForAudioFrameClassification: () => (/* binding */ UniSpeechSatForAudioFrameClassification),
/* harmony export */ UniSpeechSatForCTC: () => (/* binding */ UniSpeechSatForCTC),
/* harmony export */ UniSpeechSatForSequenceClassification: () => (/* binding */ UniSpeechSatForSequenceClassification),
/* harmony export */ UniSpeechSatModel: () => (/* binding */ UniSpeechSatModel),
/* harmony export */ UniSpeechSatPreTrainedModel: () => (/* binding */ UniSpeechSatPreTrainedModel),
/* harmony export */ ViTForImageClassification: () => (/* binding */ ViTForImageClassification),
/* harmony export */ ViTMAEModel: () => (/* binding */ ViTMAEModel),
/* harmony export */ ViTMAEPreTrainedModel: () => (/* binding */ ViTMAEPreTrainedModel),
/* harmony export */ ViTMSNForImageClassification: () => (/* binding */ ViTMSNForImageClassification),
/* harmony export */ ViTMSNModel: () => (/* binding */ ViTMSNModel),
/* harmony export */ ViTMSNPreTrainedModel: () => (/* binding */ ViTMSNPreTrainedModel),
/* harmony export */ ViTModel: () => (/* binding */ ViTModel),
/* harmony export */ ViTPreTrainedModel: () => (/* binding */ ViTPreTrainedModel),
/* harmony export */ VisionEncoderDecoderModel: () => (/* binding */ VisionEncoderDecoderModel),
/* harmony export */ VitMatteForImageMatting: () => (/* binding */ VitMatteForImageMatting),
/* harmony export */ VitMattePreTrainedModel: () => (/* binding */ VitMattePreTrainedModel),
/* harmony export */ VitPoseForPoseEstimation: () => (/* binding */ VitPoseForPoseEstimation),
/* harmony export */ VitPosePreTrainedModel: () => (/* binding */ VitPosePreTrainedModel),
/* harmony export */ VitsModel: () => (/* binding */ VitsModel),
/* harmony export */ VitsModelOutput: () => (/* binding */ VitsModelOutput),
/* harmony export */ VitsPreTrainedModel: () => (/* binding */ VitsPreTrainedModel),
/* harmony export */ Wav2Vec2BertForCTC: () => (/* binding */ Wav2Vec2BertForCTC),
/* harmony export */ Wav2Vec2BertForSequenceClassification: () => (/* binding */ Wav2Vec2BertForSequenceClassification),
/* harmony export */ Wav2Vec2BertModel: () => (/* binding */ Wav2Vec2BertModel),
/* harmony export */ Wav2Vec2BertPreTrainedModel: () => (/* binding */ Wav2Vec2BertPreTrainedModel),
/* harmony export */ Wav2Vec2ForAudioFrameClassification: () => (/* binding */ Wav2Vec2ForAudioFrameClassification),
/* harmony export */ Wav2Vec2ForCTC: () => (/* binding */ Wav2Vec2ForCTC),
/* harmony export */ Wav2Vec2ForSequenceClassification: () => (/* binding */ Wav2Vec2ForSequenceClassification),
/* harmony export */ Wav2Vec2Model: () => (/* binding */ Wav2Vec2Model),
/* harmony export */ Wav2Vec2PreTrainedModel: () => (/* binding */ Wav2Vec2PreTrainedModel),
/* harmony export */ WavLMForAudioFrameClassification: () => (/* binding */ WavLMForAudioFrameClassification),
/* harmony export */ WavLMForCTC: () => (/* binding */ WavLMForCTC),
/* harmony export */ WavLMForSequenceClassification: () => (/* binding */ WavLMForSequenceClassification),
/* harmony export */ WavLMForXVector: () => (/* binding */ WavLMForXVector),
/* harmony export */ WavLMModel: () => (/* binding */ WavLMModel),
/* harmony export */ WavLMPreTrainedModel: () => (/* binding */ WavLMPreTrainedModel),
/* harmony export */ WeSpeakerResNetModel: () => (/* binding */ WeSpeakerResNetModel),
/* harmony export */ WeSpeakerResNetPreTrainedModel: () => (/* binding */ WeSpeakerResNetPreTrainedModel),
/* harmony export */ WhisperForConditionalGeneration: () => (/* binding */ WhisperForConditionalGeneration),
/* harmony export */ WhisperModel: () => (/* binding */ WhisperModel),
/* harmony export */ WhisperPreTrainedModel: () => (/* binding */ WhisperPreTrainedModel),
/* harmony export */ XLMForQuestionAnswering: () => (/* binding */ XLMForQuestionAnswering),
/* harmony export */ XLMForSequenceClassification: () => (/* binding */ XLMForSequenceClassification),
/* harmony export */ XLMForTokenClassification: () => (/* binding */ XLMForTokenClassification),
/* harmony export */ XLMModel: () => (/* binding */ XLMModel),
/* harmony export */ XLMPreTrainedModel: () => (/* binding */ XLMPreTrainedModel),
/* harmony export */ XLMRobertaForMaskedLM: () => (/* binding */ XLMRobertaForMaskedLM),
/* harmony export */ XLMRobertaForQuestionAnswering: () => (/* binding */ XLMRobertaForQuestionAnswering),
/* harmony export */ XLMRobertaForSequenceClassification: () => (/* binding */ XLMRobertaForSequenceClassification),
/* harmony export */ XLMRobertaForTokenClassification: () => (/* binding */ XLMRobertaForTokenClassification),
/* harmony export */ XLMRobertaModel: () => (/* binding */ XLMRobertaModel),
/* harmony export */ XLMRobertaPreTrainedModel: () => (/* binding */ XLMRobertaPreTrainedModel),
/* harmony export */ XLMWithLMHeadModel: () => (/* binding */ XLMWithLMHeadModel),
/* harmony export */ XVectorOutput: () => (/* binding */ XVectorOutput),
/* harmony export */ YolosForObjectDetection: () => (/* binding */ YolosForObjectDetection),
/* harmony export */ YolosModel: () => (/* binding */ YolosModel),
/* harmony export */ YolosObjectDetectionOutput: () => (/* binding */ YolosObjectDetectionOutput),
/* harmony export */ YolosPreTrainedModel: () => (/* binding */ YolosPreTrainedModel)
/* harmony export */ });
/* harmony import */ var _configs_js__WEBPACK_IMPORTED_MODULE_0__ = __webpack_require__(/*! ./configs.js */ "./src/configs.js");
/* harmony import */ var _backends_onnx_js__WEBPACK_IMPORTED_MODULE_1__ = __webpack_require__(/*! ./backends/onnx.js */ "./src/backends/onnx.js");
/* harmony import */ var _utils_dtypes_js__WEBPACK_IMPORTED_MODULE_2__ = __webpack_require__(/*! ./utils/dtypes.js */ "./src/utils/dtypes.js");
/* harmony import */ var _utils_generic_js__WEBPACK_IMPORTED_MODULE_3__ = __webpack_require__(/*! ./utils/generic.js */ "./src/utils/generic.js");
/* harmony import */ var _utils_core_js__WEBPACK_IMPORTED_MODULE_4__ = __webpack_require__(/*! ./utils/core.js */ "./src/utils/core.js");
/* harmony import */ var _utils_hub_js__WEBPACK_IMPORTED_MODULE_5__ = __webpack_require__(/*! ./utils/hub.js */ "./src/utils/hub.js");
/* harmony import */ var _utils_constants_js__WEBPACK_IMPORTED_MODULE_6__ = __webpack_require__(/*! ./utils/constants.js */ "./src/utils/constants.js");
/* harmony import */ var _generation_logits_process_js__WEBPACK_IMPORTED_MODULE_7__ = __webpack_require__(/*! ./generation/logits_process.js */ "./src/generation/logits_process.js");
/* harmony import */ var _generation_configuration_utils_js__WEBPACK_IMPORTED_MODULE_8__ = __webpack_require__(/*! ./generation/configuration_utils.js */ "./src/generation/configuration_utils.js");
/* harmony import */ var _utils_tensor_js__WEBPACK_IMPORTED_MODULE_9__ = __webpack_require__(/*! ./utils/tensor.js */ "./src/utils/tensor.js");
/* harmony import */ var _utils_image_js__WEBPACK_IMPORTED_MODULE_10__ = __webpack_require__(/*! ./utils/image.js */ "./src/utils/image.js");
/* harmony import */ var _utils_maths_js__WEBPACK_IMPORTED_MODULE_11__ = __webpack_require__(/*! ./utils/maths.js */ "./src/utils/maths.js");
/* harmony import */ var _generation_stopping_criteria_js__WEBPACK_IMPORTED_MODULE_12__ = __webpack_require__(/*! ./generation/stopping_criteria.js */ "./src/generation/stopping_criteria.js");
/* harmony import */ var _generation_logits_sampler_js__WEBPACK_IMPORTED_MODULE_13__ = __webpack_require__(/*! ./generation/logits_sampler.js */ "./src/generation/logits_sampler.js");
/* harmony import */ var _env_js__WEBPACK_IMPORTED_MODULE_14__ = __webpack_require__(/*! ./env.js */ "./src/env.js");
/* harmony import */ var _models_whisper_generation_whisper_js__WEBPACK_IMPORTED_MODULE_15__ = __webpack_require__(/*! ./models/whisper/generation_whisper.js */ "./src/models/whisper/generation_whisper.js");
/* harmony import */ var _models_whisper_common_whisper_js__WEBPACK_IMPORTED_MODULE_16__ = __webpack_require__(/*! ./models/whisper/common_whisper.js */ "./src/models/whisper/common_whisper.js");
/**
* @file Definitions of all models available in Transformers.js.
*
* **Example:** Load and run an `AutoModel`.
*
* ```javascript
* import { AutoModel, AutoTokenizer } from '@huggingface/transformers';
*
* let tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
* let model = await AutoModel.from_pretrained('Xenova/bert-base-uncased');
*
* let inputs = await tokenizer('I love transformers!');
* let { logits } = await model(inputs);
* // Tensor {
* // data: Float32Array(183132) [-7.117443084716797, -7.107812881469727, -7.092104911804199, ...]
* // dims: (3) [1, 6, 30522],
* // type: "float32",
* // size: 183132,
* // }
* ```
*
* We also provide other `AutoModel`s (listed below), which you can use in the same way as the Python library. For example:
*
* **Example:** Load and run an `AutoModelForSeq2SeqLM`.
* ```javascript
* import { AutoModelForSeq2SeqLM, AutoTokenizer } from '@huggingface/transformers';
*
* let tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small');
* let model = await AutoModelForSeq2SeqLM.from_pretrained('Xenova/t5-small');
*
* let { input_ids } = await tokenizer('translate English to German: I love transformers!');
* let outputs = await model.generate(input_ids);
* let decoded = tokenizer.decode(outputs[0], { skip_special_tokens: true });
* // 'Ich liebe Transformatoren!'
* ```
*
* @module models
*/
//////////////////////////////////////////////////
// Model types: used internally
/**
 * Internal model-type identifiers (used internally only).
 * Built from explicit `[name, value]` pairs so every numeric id stays spelled out
 * while key order (and therefore enumeration order) matches the declaration order.
 */
const MODEL_TYPES = Object.fromEntries([
    ['EncoderOnly', 0],
    ['EncoderDecoder', 1],
    ['Seq2Seq', 2],
    ['Vision2Seq', 3],
    ['DecoderOnly', 4],
    ['MaskGeneration', 5],
    ['ImageTextToText', 6],
    ['Musicgen', 7],
    ['MultiModality', 8],
]);
//////////////////////////////////////////////////
//////////////////////////////////////////////////
// Helper functions
// NOTE: These will be populated fully later
// Presumably: model name -> MODEL_TYPES value (TODO confirm against the code that fills it).
const MODEL_TYPE_MAPPING = new Map();
// Presumably: model name -> model class (TODO confirm).
const MODEL_NAME_TO_CLASS_MAPPING = new Map();
// Presumably: model class -> model name, the inverse of the map above (TODO confirm).
const MODEL_CLASS_TO_NAME_MAPPING = new Map();
/**
 * Constructs an InferenceSession using a model file located at the specified path.
 *
 * Resolves the device and dtype (either a single value or a per-file mapping keyed by
 * `fileName`), derives the ONNX file name from the dtype suffix, and fetches the model
 * buffer together with any external data files.
 *
 * @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
 * @param {string} fileName The base name of the model file (the dtype suffix and `.onnx` are appended here).
 * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
 * @returns {Promise<{buffer: Uint8Array, session_options: Object, session_config: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
 * @throws {Error} If the dtype or `kv_cache_dtype` is invalid, if fp16 is requested on a WebGPU device
 * that does not support it, if external data is requested in Node.js, or if downloading the model
 * file or any external data file fails.
 * @private
 */
async function getSession(pretrained_model_name_or_path, fileName, options) {
    const custom_config = options.config?.['transformers.js_config'] ?? {};

    // The device may be a single string, or an object mapping file names to devices.
    let device = options.device ?? custom_config.device;
    if (device && typeof device !== 'string') {
        if (device.hasOwnProperty(fileName)) {
            device = device[fileName];
        } else {
            console.warn(`device not specified for "${fileName}". Using the default device.`);
            device = null;
        }
    }

    // If the device is not specified, we use the default (supported) execution providers.
    const selectedDevice = /** @type {import("./utils/devices.js").DeviceType} */(
        device ?? (_env_js__WEBPACK_IMPORTED_MODULE_14__.apis.IS_NODE_ENV ? 'cpu' : 'wasm')
    );
    const executionProviders = (0,_backends_onnx_js__WEBPACK_IMPORTED_MODULE_1__.deviceToExecutionProviders)(selectedDevice);

    // If options.dtype is specified, we use it to choose the suffix for the model file.
    // Otherwise, we use the default dtype for the device.
    let dtype = options.dtype ?? custom_config.dtype;
    if (typeof dtype !== 'string') {
        if (dtype && dtype.hasOwnProperty(fileName)) {
            dtype = dtype[fileName];
        } else {
            dtype = _utils_dtypes_js__WEBPACK_IMPORTED_MODULE_2__.DEFAULT_DEVICE_DTYPE_MAPPING[selectedDevice] ?? _utils_dtypes_js__WEBPACK_IMPORTED_MODULE_2__.DATA_TYPES.fp32;
            console.warn(`dtype not specified for "${fileName}". Using the default dtype (${dtype}) for this device (${selectedDevice}).`);
        }
    }
    const selectedDtype = /** @type {import("./utils/dtypes.js").DataType} */(dtype);
    if (!_utils_dtypes_js__WEBPACK_IMPORTED_MODULE_2__.DEFAULT_DTYPE_SUFFIX_MAPPING.hasOwnProperty(selectedDtype)) {
        throw new Error(`Invalid dtype: ${selectedDtype}. Should be one of: ${Object.keys(_utils_dtypes_js__WEBPACK_IMPORTED_MODULE_2__.DATA_TYPES).join(', ')}`);
    } else if (selectedDtype === _utils_dtypes_js__WEBPACK_IMPORTED_MODULE_2__.DATA_TYPES.fp16 && selectedDevice === 'webgpu' && !(await (0,_utils_dtypes_js__WEBPACK_IMPORTED_MODULE_2__.isWebGpuFp16Supported)())) {
        throw new Error(`The device (${selectedDevice}) does not support fp16.`);
    }

    // Only valid for models with a decoder
    const kv_cache_dtype = custom_config.kv_cache_dtype
        ? (typeof custom_config.kv_cache_dtype === 'string'
            ? custom_config.kv_cache_dtype
            : custom_config.kv_cache_dtype[selectedDtype] ?? 'float32')
        : undefined;
    if (kv_cache_dtype && !['float32', 'float16'].includes(kv_cache_dtype)) {
        throw new Error(`Invalid kv_cache_dtype: ${kv_cache_dtype}. Should be one of: float32, float16`);
    }
    const session_config = {
        dtype: selectedDtype,
        kv_cache_dtype,
    };

    // Construct the model file name
    const suffix = _utils_dtypes_js__WEBPACK_IMPORTED_MODULE_2__.DEFAULT_DTYPE_SUFFIX_MAPPING[selectedDtype];
    const modelFileName = `${options.subfolder ?? ''}/${fileName}${suffix}.onnx`;

    const session_options = { ...options.session_options };
    // Overwrite `executionProviders` if not specified
    session_options.executionProviders ??= executionProviders;

    // Overwrite `freeDimensionOverrides` if specified in config and not set in session options
    const free_dimension_overrides = custom_config.free_dimension_overrides;
    if (free_dimension_overrides) {
        session_options.freeDimensionOverrides ??= free_dimension_overrides;
    } else if (selectedDevice.startsWith('webnn') && !session_options.freeDimensionOverrides) {
        console.warn(
            'WebNN does not currently support dynamic shapes and requires `free_dimension_overrides` to be set in config.json as a field within "transformers.js_config". ' +
            'When `free_dimension_overrides` is not set, you may experience significant performance degradation.'
        );
    }

    // Decide up front whether a `<file>.onnx_data` external data file must be fetched,
    // so the Node.js rejection happens before any download is started.
    const use_external_data_format = options.use_external_data_format ?? custom_config.use_external_data_format;
    const load_external_data = Boolean(use_external_data_format) && (
        use_external_data_format === true ||
        (
            typeof use_external_data_format === 'object' &&
            use_external_data_format.hasOwnProperty(fileName) &&
            use_external_data_format[fileName] === true
        )
    );
    if (load_external_data && _env_js__WEBPACK_IMPORTED_MODULE_14__.apis.IS_NODE_ENV) {
        throw new Error('External data format is not yet supported in Node.js');
    }

    const bufferPromise = (0,_utils_hub_js__WEBPACK_IMPORTED_MODULE_5__.getModelFile)(pretrained_model_name_or_path, modelFileName, true, options);

    /** @type {Promise<{path: string, data: Uint8Array}>[]} */
    let externalDataPromises = [];
    if (load_external_data) {
        const path = `${fileName}${suffix}.onnx_data`;
        const fullPath = `${options.subfolder ?? ''}/${path}`;
        // An async IIFE (instead of `new Promise(async (resolve) => ...)`) so that a failed
        // download rejects this promise rather than leaving it pending forever.
        externalDataPromises.push((async () => ({
            path,
            data: await (0,_utils_hub_js__WEBPACK_IMPORTED_MODULE_5__.getModelFile)(pretrained_model_name_or_path, fullPath, true, options),
        }))());
    } else if (session_options.externalData !== undefined) {
        externalDataPromises = session_options.externalData.map(async (ext) => {
            // if the external data is a string, fetch the file and replace the string with its content
            if (typeof ext.data === "string") {
                const ext_buffer = await (0,_utils_hub_js__WEBPACK_IMPORTED_MODULE_5__.getModelFile)(pretrained_model_name_or_path, ext.data, true, options);
                return { ...ext, data: ext_buffer };
            }
            return ext;
        });
    }

    // Await the model buffer and the external data together: a rejection from either one
    // propagates to the caller, and neither promise is left floating unhandled.
    const [buffer, externalData] = await Promise.all([bufferPromise, Promise.all(externalDataPromises)]);
    if (externalData.length > 0) {
        session_options.externalData = externalData;
    }

    if (selectedDevice === 'webgpu') {
        const shapes = (0,_configs_js__WEBPACK_IMPORTED_MODULE_0__.getKeyValueShapes)(options.config, {
            prefix: 'present',
        });
        if (Object.keys(shapes).length > 0 && !(0,_backends_onnx_js__WEBPACK_IMPORTED_MODULE_1__.isONNXProxy)()) {
            // Only set preferredOutputLocation if shapes are present and we aren't proxying ONNX
            /** @type {Record<string, import('onnxruntime-common').Tensor.DataLocation>} */
            const preferredOutputLocation = {};
            for (const key in shapes) {
                preferredOutputLocation[key] = 'gpu-buffer';
            }
            session_options.preferredOutputLocation = preferredOutputLocation;
        }
    }

    return { buffer, session_options, session_config };
}
/**
 * Helper function to create multiple InferenceSession objects.
 *
 * Each entry of `names` is loaded concurrently: its file is resolved and fetched through
 * `getSession`, then handed to `createInferenceSession`.
 *
 * @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
 * @param {Record<string, string>} names Maps each session key to the model file name to load for it.
 * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
 * @returns {Promise<Record<string, any>>} A Promise that resolves to a dictionary of InferenceSession objects, keyed like `names`.
 * @private
 */
async function constructSessions(pretrained_model_name_or_path, names, options) {
    // Load one session and pair it with its key so the results can be rebuilt into an object.
    const loadSession = async (key) => {
        const fileName = names[key];
        const { buffer, session_options, session_config } = await getSession(pretrained_model_name_or_path, fileName, options);
        const session = await (0,_backends_onnx_js__WEBPACK_IMPORTED_MODULE_1__.createInferenceSession)(buffer, session_options, session_config);
        return [key, session];
    };
    const entries = await Promise.all(Object.keys(names).map(loadSession));
    return Object.fromEntries(entries);
}
/**
 * Helper function to load multiple optional configuration files concurrently.
 *
 * Each file is requested via `getModelJSON` with its third argument set to `false`
 * (presumably "not fatal", so a missing file does not abort loading — confirm in hub.js).
 *
 * @param {string} pretrained_model_name_or_path The path to the directory containing the config file.
 * @param {Record<string, string>} names Maps each result key to the config file name to load.
 * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the configs.
 * @returns {Promise<Record<string, any>>} A Promise that resolves to a dictionary of configuration objects, keyed like `names`.
 * @private
 */
async function getOptionalConfigs(pretrained_model_name_or_path, names, options) {
    const keys = Object.keys(names);
    const configs = await Promise.all(
        keys.map((key) => (0,_utils_hub_js__WEBPACK_IMPORTED_MODULE_5__.getModelJSON)(pretrained_model_name_or_path, names[key], false, options))
    );
    // `Promise.all` preserves order, so configs[i] belongs to keys[i].
    return Object.fromEntries(keys.map((key, i) => [key, configs[i]]));
}
/**
 * Validate model inputs against the session's declared input names.
 *
 * Only inputs named in `session.inputNames` are kept. Each must be a `Tensor`; when ONNX is
 * proxied (`isONNXProxy()`), every tensor is cloned, because the proxy takes ownership of it.
 * Extra inputs are ignored with a warning.
 *
 * @param {Object} session The InferenceSession object that will be run.
 * @param {Object} inputs The inputs to check.
 * @returns {Record<string, Tensor>} The checked inputs (a prototype-less object).
 * @throws {Error} If any inputs are missing.
 * @private
 */
function validateInputs(session, inputs) {
    /**
     * NOTE: Create either a shallow or deep copy based on `onnx.wasm.proxy`
     * @type {Record<string, Tensor>}
     */
    const checkedInputs = Object.create(null);
    const missingInputs = [];
    for (const inputName of session.inputNames) {
        const tensor = inputs[inputName];
        // Rare case where one of the model's input names corresponds to a built-in
        // object name (e.g., toString), which would cause a simple (!tensor) check to fail,
        // because it's not undefined but a function.
        if (!(tensor instanceof _utils_tensor_js__WEBPACK_IMPORTED_MODULE_9__.Tensor)) {
            missingInputs.push(inputName);
            continue;
        }
        // NOTE: When `env.wasm.proxy is true` the tensor is moved across the Worker
        // boundary, transferring ownership to the worker and invalidating the tensor.
        // So, in this case, we simply sacrifice a clone for it.
        checkedInputs[inputName] = (0,_backends_onnx_js__WEBPACK_IMPORTED_MODULE_1__.isONNXProxy)() ? tensor.clone() : tensor;
    }
    if (missingInputs.length > 0) {
        throw new Error(
            `An error occurred during model execution: "Missing the following inputs: ${missingInputs.join(', ')}."`);
    }

    const providedNames = Object.keys(inputs);
    const numInputsProvided = providedNames.length;
    const numInputsNeeded = session.inputNames.length;
    if (numInputsProvided > numInputsNeeded) {
        // No missing inputs, but too many inputs were provided.
        // Warn the user and ignore the extra inputs.
        const expectedNames = new Set(session.inputNames);
        const ignored = providedNames.filter((inputName) => !expectedNames.has(inputName));
        console.warn(`WARNING: Too many inputs were provided (${numInputsProvided} > ${numInputsNeeded}). The following inputs will be ignored: "${ignored.join(', ')}".`);
    }
    return checkedInputs;
}
/**
 * Executes an InferenceSession using the specified inputs.
 *
 * Inputs are checked by `validateInputs` (missing inputs throw, extra ones are ignored).
 * Each checked tensor is then unwrapped to its `ort_tensor` before calling `session.run`,
 * and the raw outputs are re-wrapped by `replaceTensors`.
 *
 * @param {Object} session The InferenceSession object to run.
 * @param {Object} inputs An object that maps input names to input tensors.
 * @returns {Promise<Object>} A Promise that resolves to an object that maps output names to output tensors.
 * @throws {Error} Re-throws any error raised while running the session, after logging it and the inputs.
 * @private
 */
async function sessionRun(session, inputs) {
    const checkedInputs = validateInputs(session, inputs);
    try {
        // ONNX Runtime consumes its own tensor objects, so hand it the wrapped ort tensors.
        const ortFeed = Object.fromEntries(
            Object.keys(checkedInputs).map((name) => [name, checkedInputs[name].ort_tensor])
        );
        const rawOutput = await session.run(ortFeed);
        return replaceTensors(rawOutput);
    } catch (e) {
        // This usually occurs when the inputs are of the wrong type.
        console.error(`An error occurred during model execution: "${e}".`);
        console.error('Inputs given to model:', checkedInputs);
        throw e;
    }
}
/**
 * Replaces ONNX Tensor objects with custom Tensor objects to support additional functions.
 *
 * Walks `obj` in place (recursing into every nested value whose `typeof` is `'object'`).
 * Each value recognised by `isONNXTensor` is swapped for `new Tensor(value)`.
 *
 * @param {Object} obj The object to replace tensor objects in (mutated in place).
 * @returns {Object} The same object, with tensor objects replaced by custom Tensor objects.
 * @private
 */
function replaceTensors(obj) {
    for (const key in obj) {
        const value = obj[key];
        if ((0,_backends_onnx_js__WEBPACK_IMPORTED_MODULE_1__.isONNXTensor)(value)) {
            obj[key] = new _utils_tensor_js__WEBPACK_IMPORTED_MODULE_9__.Tensor(value);
            continue;
        }
        // Note: `typeof null === 'object'`; recursing into null is harmless (for-in does nothing).
        if (typeof value === 'object') {
            replaceTensors(value);
        }
    }
    return obj;
}
/**
 * Converts an array or Tensor of integers to an int64 Tensor.
 *
 * A Tensor is returned unchanged. A flat array becomes a tensor of dims `[1, length]`.
 * An array of arrays becomes `[rows, rowLength]`, and every row must have the same length.
 *
 * @param {any[]|Tensor} items The input integers to be converted.
 * @returns {Tensor} The int64 Tensor with the converted values.
 * @throws {Error} If the input array is empty or the input is a batched Tensor and not all sequences have the same length.
 * @private
 */
function toI64Tensor(items) {
    const { Tensor } = _utils_tensor_js__WEBPACK_IMPORTED_MODULE_9__;
    if (items instanceof Tensor) {
        return items;
    }
    if (items.length === 0) {
        throw new Error("items must be non-empty");
    }

    const toBigInt64 = (values) => BigInt64Array.from(values.map((x) => BigInt(x)));

    if (!Array.isArray(items[0])) {
        // flat: treat as a single sequence
        return new Tensor('int64', toBigInt64(items), [1, items.length]);
    }

    // batched: all rows must share the first row's length
    const rowLength = items[0].length;
    const isRagged = items.some((row) => row.length !== rowLength);
    if (isRagged) {
        throw new Error("Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' and/or 'truncation=True' to have batched tensors with the same length.");
    }
    return new Tensor('int64', toBigInt64(items.flat()), [items.length, rowLength]);
}
/**
 * Creates a boolean tensor holding exactly one value (dims `[1]`).
 * @param {boolean} value The value of the tensor.
 * @returns {Tensor} The boolean tensor.
 * @private
 */
function boolTensor(value) {
    const data = [value];
    const dims = [1];
    return new _utils_tensor_js__WEBPACK_IMPORTED_MODULE_9__.Tensor('bool', data, dims);
}
// JS doesn't support mixins, so we define some reused functions here, and allow "this" to be passed in
/**
 * Perform forward pass on the seq2seq model (both encoder and decoder).
 *
 * When `encoder_outputs` is absent, the encoder (`self.sessions['model']`) is run first and its
 * `last_hidden_state` is used. The decoder (`decoder_model_merged`) then receives
 * `decoder_input_ids` as `input_ids`, plus the encoder hidden states. It also receives
 * `encoder_attention_mask`, copied from `attention_mask`, when the session declares that input.
 *
 * @param {Object} self The seq2seq model object.
 * @param {Object} model_inputs The input object for the model containing encoder and decoder inputs.
 * @returns {Promise<Seq2SeqLMOutput>} Promise that resolves with the output of the seq2seq model.
 * @private
 */
async function seq2seqForward(self, model_inputs) {
    // `input_ids` is destructured only so that it is excluded from the decoder inputs;
    // the encoder picks it from `model_inputs` directly.
    const {
        encoder_outputs: providedEncoderOutputs,
        input_ids: _unusedEncoderInputIds,
        decoder_input_ids,
        ...decoderInputs
    } = model_inputs;

    let encoderHiddenStates = providedEncoderOutputs;
    if (!encoderHiddenStates) {
        // Encoder outputs are not given, so we must compute them.
        const encoderInputs = (0,_utils_core_js__WEBPACK_IMPORTED_MODULE_4__.pick)(model_inputs, self.sessions['model'].inputNames);
        const encoderResult = await encoderForward(self, encoderInputs);
        encoderHiddenStates = encoderResult.last_hidden_state;
    }

    decoderInputs.input_ids = decoder_input_ids;
    decoderInputs.encoder_hidden_states = encoderHiddenStates;

    const decoderSession = self.sessions['decoder_model_merged'];
    if (decoderSession.inputNames.includes('encoder_attention_mask')) {
        decoderInputs.encoder_attention_mask = model_inputs.attention_mask;
    }
    return await decoderForward(self, decoderInputs, true);
}
/**
 * Forward pass of an encoder model.
 *
 * Picks the session's declared inputs from `model_inputs` and fills in two inputs if they are missing:
 * - `inputs_embeds`: computed with `self.encode_text` from `input_ids`, when the session
 *   expects embeddings.
 * - `token_type_ids`: an all-zero int64 tensor shaped like `input_ids`, when the session
 *   expects it but the tokenizer did not produce it.
 *
 * @param {Object} self The encoder model.
 * @param {Object} model_inputs The input data to be used for the forward pass.
 * @returns {Promise<Object>} The model's outputs.
 * @throws {Error} If an input that must be derived from `input_ids` is missing and no `input_ids` were given.
 * @private
 */
async function encoderForward(self, model_inputs) {
    const session = self.sessions['model'];
    const encoderFeeds = (0,_utils_core_js__WEBPACK_IMPORTED_MODULE_4__.pick)(model_inputs, session.inputNames);

    if (session.inputNames.includes('inputs_embeds') && !encoderFeeds.inputs_embeds) {
        if (!model_inputs.input_ids) {
            throw new Error('Both `input_ids` and `inputs_embeds` are missing in the model inputs.');
        }
        encoderFeeds.inputs_embeds = await self.encode_text({ input_ids: model_inputs.input_ids });
    }

    if (session.inputNames.includes('token_type_ids') && !encoderFeeds.token_type_ids) {
        // Assign default `token_type_ids` (all zeroes) if the model expects them but they weren't
        // created by the tokenizer. Take the shape from `model_inputs.input_ids` rather than
        // `encoderFeeds.input_ids`: the latter is absent when the session consumes `inputs_embeds`
        // instead of `input_ids` (it is the same tensor whenever both exist).
        const reference_ids = model_inputs.input_ids;
        if (!reference_ids) {
            throw new Error('Both `input_ids` and `token_type_ids` are missing in the model inputs.');
        }
        encoderFeeds.token_type_ids = new _utils_tensor_js__WEBPACK_IMPORTED_MODULE_9__.Tensor(
            'int64',
            new BigInt64Array(reference_ids.data.length),
            reference_ids.dims
        );
    }
    return await sessionRun(session, encoderFeeds);
}
/**
 * Forward pass of a decoder model.
 *
 * Runs `self.sessions['decoder_model_merged']` when `is_encoder_decoder` is true, otherwise
 * `self.sessions['model']`. Before running, it fills in any of the following that the session expects:
 * - `use_cache_branch`: true iff `past_key_values` was given.
 * - `position_ids`: only when an `attention_mask` is present and no ids were supplied.
 * - the flattened cache entries, via `self.addPastKeyValues`.
 *
 * @param {Object} self The decoder model.
 * @param {Object} model_inputs The input data to be used for the forward pass.
 * @param {boolean} [is_encoder_decoder=false] Whether to use the merged decoder session of an encoder-decoder model.
 * @returns {Promise<Object>} The logits and past key values.
 * @private
 */
async function decoderForward(self, model_inputs, is_encoder_decoder = false) {
    const sessionKey = is_encoder_decoder ? 'decoder_model_merged' : 'model';
    const session = self.sessions[sessionKey];
    const accepts = (inputName) => session.inputNames.includes(inputName);

    // Split the cache off; it is re-added below in the session's own input format.
    const { past_key_values, ...new_model_inputs } = model_inputs;

    if (accepts('use_cache_branch')) {
        new_model_inputs.use_cache_branch = boolTensor(Boolean(past_key_values));
    }

    const needsPositionIds = accepts('position_ids')
        && new_model_inputs.attention_mask
        && !new_model_inputs.position_ids;
    if (needsPositionIds) {
        new_model_inputs.position_ids = createPositionIds(new_model_inputs, past_key_values);
    }

    // Unpack the `past_key_values` object into model inputs
    self.addPastKeyValues(new_model_inputs, past_key_values);

    // Feed the session only the inputs it declares.
    const feeds = (0,_utils_core_js__WEBPACK_IMPORTED_MODULE_4__.pick)(new_model_inputs, session.inputNames);
    return await sessionRun(session, feeds);
}
/**
 * Forward pass of an image-text-to-text (vision-language) model.
 *
 * Unless `inputs_embeds` is supplied, the tokens are embedded with `self.encode_text`. When
 * `pixel_values` are present, there are two cases:
 *  - on a multi-token step, the images are encoded with `self.encode_image` and merged into the text
 *    embeddings by `self._merge_input_ids_with_image_features`, which may also update the mask;
 *  - on a single-token step with a cache, the attention mask is rebuilt as ones over the cached
 *    positions followed by the trailing mask column(s).
 * For `qwen2_vl` models without explicit `position_ids`, these come from `self.get_rope_index`.
 * The merged decoder session (`decoderForward(..., true)`) produces the outputs.
 *
 * @param {Object} self The image-text-to-text model.
 * @param {Object} model_inputs The input data to be used for the forward pass.
 * @param {Tensor} [model_inputs.input_ids=null]
 * @param {Tensor} [model_inputs.attention_mask=null]
 * @param {Tensor} [model_inputs.pixel_values=null]
 * @param {Tensor} [model_inputs.position_ids=null]
 * @param {Tensor} [model_inputs.inputs_embeds=null]
 * @param {Tensor} [model_inputs.past_key_values=null]
 * @param {Object} [model_inputs.generation_config=null]
 * @param {Object} [model_inputs.logits_processor=null]
 * @returns {Promise<Object>} The outputs of the merged decoder session.
 * @private
 */
async function imageTextToTextForward(self, {
    // Produced by the tokenizer/processor:
    input_ids = null,
    attention_mask = null,
    pixel_values = null,
    // Used during generation:
    position_ids = null,
    inputs_embeds = null,
    past_key_values = null,
    // Generic generation parameters
    generation_config = null,
    logits_processor = null,
    // TODO: needed?
    ...kwargs
}) {
    const tensor_ops = _utils_tensor_js__WEBPACK_IMPORTED_MODULE_9__;

    if (!inputs_embeds) {
        inputs_embeds = await self.encode_text({ input_ids, ...kwargs });

        if (pixel_values) {
            const step_length = input_ids.dims[1];
            if (step_length !== 1) {
                // Prompt step: splice the image features into the text embeddings.
                const image_features = await self.encode_image({ pixel_values, ...kwargs });
                const merged = self._merge_input_ids_with_image_features({
                    image_features,
                    inputs_embeds,
                    input_ids,
                    attention_mask,
                });
                inputs_embeds = merged.inputs_embeds;
                attention_mask = merged.attention_mask;
            } else if (past_key_values) {
                // Cached single-token step: attend to every cached position plus the new token.
                const past_length = Object.values(past_key_values)[0].dims.at(-2);
                const mask_length = attention_mask.dims[1];
                attention_mask = tensor_ops.cat([
                    tensor_ops.ones([input_ids.dims[0], past_length]),
                    attention_mask.slice(null, [mask_length - step_length, mask_length]),
                ], 1);
            }
        }
    }

    if (!position_ids && self.config.model_type === 'qwen2_vl') {
        // Special case for qwen2_vl models (multimodal rotary position ids).
        // @ts-ignore
        const { image_grid_thw, video_grid_thw } = kwargs;
        position_ids = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)[0];
    }

    return await decoderForward(self, {
        inputs_embeds,
        past_key_values,
        attention_mask,
        position_ids,
        generation_config,
        logits_processor,
    }, true);
}
/**
 * Row-wise exclusive running count of an attention mask, with masked slots set to 1.
 * For 0/1 masks this matches the Python idiom:
 * ```python
 * x = attention_mask.long().cumsum(-1) - 1
 * x.masked_fill_(attention_mask == 0, 1)
 * ```
 * Each unmasked slot receives the sum of the mask values before it in the same row; each slot whose
 * mask value is `0n` receives `1n`.
 * @param {Tensor} attention_mask A 2-D mask whose `data` holds BigInt values (compared against `0n`).
 * @returns {{data: BigInt64Array, dims: number[]}} The filled values and the mask's own `dims`.
 */
function cumsum_masked_fill(attention_mask) {
    const [batch_size, seq_length] = attention_mask.dims;
    const mask = attention_mask.data;
    const data = new BigInt64Array(mask.length);

    for (let row = 0; row < batch_size; ++row) {
        const row_offset = row * seq_length;
        let running = 0n;
        for (let col = 0; col < seq_length; ++col) {
            const k = row_offset + col;
            const value = mask[k];
            if (value !== 0n) {
                data[k] = running;
                running += value;
            } else {
                data[k] = 1n;
            }
        }
    }
    return { data, dims: attention_mask.dims };
}
/**
 * Build `position_ids` from the attention mask, so that left-padded batches get correct positions.
 *
 * Equivalent to:
 * ```python
 * position_ids = attention_mask.long().cumsum(-1) - 1
 * position_ids.masked_fill_(attention_mask == 0, 1)
 * if past_key_values:
 *     position_ids = position_ids[:, -input_ids.shape[1] :]
 * ```
 * @param {Object} model_inputs Must contain `attention_mask`. When a cache is given, it must also
 * contain `input_ids` or `inputs_embeds`, whose second dimension is the number of new positions.
 * @param {Object|null} [past_key_values=null] The cache. When truthy, only the trailing positions
 * are returned.
 * @returns {Tensor} An int64 tensor of position ids.
 */
function createPositionIds(model_inputs, past_key_values = null) {
    const { input_ids, inputs_embeds, attention_mask } = model_inputs;
    const { data, dims } = cumsum_masked_fill(attention_mask);
    const all_positions = new _utils_tensor_js__WEBPACK_IMPORTED_MODULE_9__.Tensor('int64', data, dims);
    if (!past_key_values) {
        return all_positions;
    }
    // With a cache, only the positions of the tokens fed in this step are needed.
    const new_tokens = (input_ids ?? inputs_embeds).dims.at(1);
    return all_positions.slice(null, [-new_tokens, null]);
}
/**
 * Trim `model_inputs` for the next decoder-only generation step. The object is mutated in place
 * and returned.
 *
 * Without a cache the inputs are returned untouched. With a cache of length `past_length`:
 *  1. if the attention mask is longer than `input_ids`, nothing is trimmed (only newly generated
 *     tokens are passed forward);
 *  2. else if `input_ids` is longer than the cache, the already-cached prefix is sliced off;
 *  3. otherwise `input_ids` are assumed to hold only unprocessed tokens. For VLMs whose ids contain
 *     `config.image_token_index`, the ids and mask are re-sized to account for
 *     `config.num_image_tokens` cached image positions.
 *
 * @param {Object} self The model (reads `self.config`).
 * @param {*} input_ids Unused here; the ids are read from `model_inputs.input_ids`.
 * @param {Object} model_inputs The current step's inputs.
 * @param {Object} generation_config Unused here.
 * @returns {Object} `model_inputs`.
 * @throws {Error} If an image token is present but `num_image_tokens` is not configured.
 */
function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
    const cache = model_inputs.past_key_values;
    if (!cache) {
        return model_inputs;
    }
    const past_length = Object.values(cache)[0].dims.at(-2);
    const current_ids = model_inputs.input_ids;
    const mask = model_inputs.attention_mask;

    // 1 - Some inputs live only in the cache (e.g. a prompt passed as inputs_embeds). Trimming is
    //     not needed since we only pass the generated tokens to the next forward pass.
    if (mask && mask.dims[1] > current_ids.dims[1]) {
        return model_inputs;
    }

    // 2 - input_ids still hold every token; drop the ones already in the cache.
    //     NOTE: Required for phi models.
    //     See https://github.com/huggingface/transformers/issues/30809#issuecomment-2111918479
    if (past_length < current_ids.dims[1]) {
        model_inputs.input_ids = current_ids.slice(null, [past_length, null]);
        return model_inputs;
    }

    // 3 - input_ids hold only unprocessed tokens. Only VLMs need further adjustment.
    const image_token_index = self.config.image_token_index;
    // `!=` lets null match undefined; `==` lets a number match a bigint element.
    const has_image_token = image_token_index != null
        && current_ids.data.some((token) => token == image_token_index);
    if (!has_image_token) {
        return model_inputs;
    }

    // TODO: Support multiple image tokens
    const num_image_tokens = self.config.num_image_tokens;
    if (!num_image_tokens) {
        throw new Error('`num_image_tokens` is missing in the model configuration.');
    }
    const num_new_tokens = current_ids.dims[1] - (past_length - num_image_tokens);
    model_inputs.input_ids = current_ids.slice(null, [-num_new_tokens, null]);
    // TODO: The attention mask should be formed from the attention mask passed in model_inputs
    model_inputs.attention_mask = (0,_utils_tensor_js__WEBPACK_IMPORTED_MODULE_9__.ones)([1, past_length + num_new_tokens]);
    return model_inputs;
}
/**
 * Prepare the next encoder-decoder generation step.
 *
 * Returns a shallow copy of `model_inputs` with `decoder_input_ids` set from `input_ids`. When a
 * cache is present, only the last token of each sequence is kept.
 * @param {Object} self Unused.
 * @param {Array} input_ids Per-sequence token arrays (each must support `.at(-1)`).
 * @param {Object} model_inputs The current inputs (not mutated).
 * @param {Object} generation_config Unused.
 * @returns {Object} The new inputs object.
 */
function encoder_decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
    const decoder_ids = model_inputs.past_key_values
        ? input_ids.map((sequence) => [sequence.at(-1)])
        : input_ids;
    return { ...model_inputs, decoder_input_ids: toI64Tensor(decoder_ids) };
}
/**
 * Dispatch generation-input preparation for image-text-to-text models, using the encoder-decoder
 * variant when `self.config.is_encoder_decoder` is truthy and the decoder-only variant otherwise.
 * @param {Object} self The model.
 * @param {...*} args Forwarded unchanged (`input_ids, model_inputs, generation_config`).
 * @returns {Object} Whatever the selected preparer returns.
 */
function image_text_to_text_prepare_inputs_for_generation(self, ...args) {
    const prepare = self.config.is_encoder_decoder
        ? encoder_decoder_prepare_inputs_for_generation
        : decoder_prepare_inputs_for_generation;
    return prepare(self, ...args);
}
/**
 * Prepare the next generation step for multi-modality models. `model_inputs` is mutated in place
 * and returned.
 *
 * - With classifier-free guidance (`guidance_scale > 1`), the batch is doubled along dim 0. On a
 *   cached step, the new ids are duplicated. On the first step, the second half is filled with
 *   `pad_token_id` and its attention mask is zeroed.
 * - On cached steps, or when no `pixel_values` were given, `pixel_values` is replaced by an empty
 *   `[0, 0, 3, 384, 384]` placeholder.
 * - On cached steps, image masks are set for a single text token and no image tokens.
 *
 * @param {Object} self Unused.
 * @param {*} input_ids Unused; ids are read from `model_inputs.input_ids`.
 * @param {Object} model_inputs The current step's inputs.
 * @param {Object} generation_config Reads `guidance_scale` and `pad_token_id`.
 * @returns {Object} `model_inputs`.
 */
function multimodality_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
    const tensor_ops = _utils_tensor_js__WEBPACK_IMPORTED_MODULE_9__;
    const has_past_key_values = !!model_inputs.past_key_values;
    const use_cfg = generation_config.guidance_scale !== null && generation_config.guidance_scale > 1;

    if (use_cfg && has_past_key_values) {
        // Conditional and unconditional halves receive the same new token(s).
        // NOTE: attention_mask handled in generation
        model_inputs.input_ids = tensor_ops.cat([model_inputs.input_ids, model_inputs.input_ids], 0);
    } else if (use_cfg) {
        // First step: the unconditional half is all padding and fully masked out.
        const prompt_ids = model_inputs.input_ids;
        model_inputs.input_ids = tensor_ops.cat([
            prompt_ids,
            tensor_ops.full_like(prompt_ids, BigInt(generation_config.pad_token_id)),
        ], 0);
        const prompt_mask = model_inputs.attention_mask;
        model_inputs.attention_mask = tensor_ops.cat([
            prompt_mask,
            tensor_ops.full_like(prompt_mask, 0n),
        ], 0);
    }

    if (has_past_key_values || !model_inputs.pixel_values) {
        model_inputs.pixel_values = tensor_ops.full([0, 0, 3, 384, 384], 1.0);
    }

    if (has_past_key_values) {
        const batch_size = 1;
        const num_img_tokens = 0;
        const num_text_tokens = 1;
        const seq_len = num_img_tokens + num_text_tokens;
        const has_image = num_img_tokens > 0;
        // Text tokens come first (false); any remaining positions are image tokens (true).
        model_inputs.images_seq_mask = new tensor_ops.Tensor(
            'bool',
            Array.from({ length: seq_len }, (_, i) => i >= num_text_tokens),
            [batch_size, seq_len],
        );
        model_inputs.images_emb_mask = new tensor_ops.Tensor(
            'bool',
            Array.from({ length: num_img_tokens }, () => has_image),
            [batch_size, 1, num_img_tokens],
        );
    }
    return model_inputs;
}
//////////////////////////////////////////////////
//////////////////////////////////////////////////
/**
* A base class for pre-trained models that provides the model configuration and an ONNX session.
*/
class PreTrainedModel extends _utils_generic_js__WEBPACK_IMPORTED_MODULE_3__.Callable {
main_input_name = 'input_ids';
forward_params = ['input_ids', 'attention_mask'];
/**
* Creates a new instance of the `PreTrainedModel` class.
* @param {import('./configs.js').PretrainedConfig} config The model configuration.
* @param {Record<string, any>} sessions The inference sessions for the model.
* @param {Record<string, Object>} configs Additional configuration files (e.g., generation_config.json).
*/
constructor(config, sessions, configs) {
super();
this.config = config;
this.sessions = sessions;
this.configs = configs;
const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor);
const modelType = MODEL_TYPE_MAPPING.get(modelName);
this.can_generate = false;
this._forward = null;
this._prepare_inputs_for_generation = null;
switch (modelType) {
case MODEL_TYPES.DecoderOnly:
this.can_generate = true;
this._forward = decoderForward;
this._prepare_inputs_for_generation = decoder_prepare_inputs_for_generation;
break;
case MODEL_TYPES.Seq2Seq:
case MODEL_TYPES.Vision2Seq:
case MODEL_TYPES.Musicgen:
this.can_generate = true;
this._forward = seq2seqForward;
this._prepare_inputs_for_generation = encoder_decoder_prepare_inputs_for_generation;
break;
case MODEL_TYPES.EncoderDecoder:
this._forward = seq2seqForward;
break;
case MODEL_TYPES.ImageTextToText:
this.can_generate = true;
this._forward = imageTextToTextForward;
this._prepare_inputs_for_generation = image_text_to_text_prepare_inputs_for_generation;
break;
case MODEL_TYPES.MultiModality:
this.can_generate = true;
this._prepare_inputs_for_generation = multimodality_prepare_inputs_for_generation;
break;
default:
// should be MODEL_TYPES.EncoderOnly
this._forward = encoderForward;
break;
}
if (this.can_generate) {
this.forward_params.push('past_key_values');
}
/** @type {import('./configs.js').TransformersJSConfig} */
this.custom_config = this.config['transformers.js_config'] ?? {};
}
/**
* Disposes of all the ONNX sessions that were created during inference.
* @returns {Promise<unknown[]>} An array of promises, one for each ONNX session that is being disposed.
* @todo Use https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry
*/
async dispose() {
const promises = [];
for (const session of Object.values(this.sessions)) {
if (session?.handler?.dispose) {
promises.push(session.handler.dispose())
}
}
return await Promise.all(promises);
}
/**
* Instantiate one of the model classes of the library from a pretrained model.
*
* The model class to instantiate is selected based on the `model_type` property of the config object
* (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible)
*
* @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either:
* - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
* Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
* user or organization name, like `dbmdz/bert-base-german-cased`.
* - A path to a *directory* containing model weights, e.g., `./my_model_directory/`.
* @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
*
* @returns {Promise<PreTrainedModel>} A new instance of the `PreTrainedModel` class.
*/
static async from_pretrained(pretrained_model_name_or_path, {
progress_callback = null,
config = null,
cache_dir = null,
local_files_only = false,
revision = 'main',
model_file_name = null,
subfolder = 'onnx',
device = null,
dtype = null,
use_external_data_format = null,
session_options = {},
} = {}) {
let options = {
progress_callback,
config,
cache_dir,
local_files_only,
revision,
model_file_name,
subfolder,
device,
dtype,
use_external_data_format,
session_options,
}
const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this);
const modelType = MODEL_TYPE_MAPPING.get(modelName);
config = options.config = await _configs_js__WEBPACK_IMPORTED_MODULE_0__.AutoConfig.from_pretrained(pretrained_model_name_or_path, options);
let info;
if (modelType === MODEL_TYPES.DecoderOnly) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: options.model_file_name ?? 'model',
}, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: 'encoder_model',
decoder_model_merged: 'decoder_model_merged',
}, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else if (modelType === MODEL_TYPES.MaskGeneration) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: 'vision_encoder',
prompt_encoder_mask_decoder: 'prompt_encoder_mask_decoder',
}, options),
]);
} else if (modelType === MODEL_TYPES.EncoderDecoder) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: 'encoder_model',
decoder_model_merged: 'decoder_model_merged',
}, options),
]);
} else if (modelType === MODEL_TYPES.ImageTextToText) {
const sessions = {
embed_tokens: 'embed_tokens',
vision_encoder: 'vision_encoder',
decoder_model_merged: 'decoder_model_merged',
}
if (config.is_encoder_decoder) {
sessions['model'] = 'encoder_model';
}
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, sessions, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else if (modelType === MODEL_TYPES.Musicgen) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: 'text_encoder',
decoder_model_merged: 'decoder_model_merged',
encodec_decode: 'encodec_decode',
}, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else if (modelType === MODEL_TYPES.MultiModality) {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
prepare_inputs_embeds: 'prepare_inputs_embeds',
model: 'language_model',
lm_head: 'lm_head',
gen_head: 'gen_head',
gen_img_embeds: 'gen_img_embeds',
image_decode: 'image_decode',
}, options),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
]);
} else { // should be MODEL_TYPES.EncoderOnly
if (modelType !== MODEL_TYPES.EncoderOnly) {
console.warn(`Model type for '${modelName ?? config?.model_type}' not found, assuming encoder-only architecture. Please report this at ${_utils_constants_js__WEBPACK_IMPORTED_MODULE_6__.GITHUB_ISSUE_URL}.`)
}
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: options.model_file_name ?? 'model',
}, options),
]);
}
// @ts-ignore
return new this(config, ...info);
}
/**
* Runs the model with the provided inputs
* @param {Object} model_inputs Object containing input tensors
* @returns {Promise<Object>} Object containing output tensors
*/
async _call(model_inputs) {
return await this.forward(model_inputs);
}
/**
* Forward method for a pretrained model. If not overridden by a subclass, the correct forward method
* will be chosen based on the model type.
* @param {Object} model_inputs The input data to the model in the format specified in the ONNX model.
* @returns {Promise<Object>} The output data from the model in the format specified in the ONNX model.
* @throws {Error} This method must be implemented in subclasses.
*/
async forward(model_inputs) {
return await this._forward(this, model_inputs);
}
/**
* Get the model's generation config, if it exists.
* @returns {GenerationConfig|null} The model's generation config if it exists, otherwise `null`.
*/
get generation_config() {
return this.configs?.generation_config ?? null;
}
/**
* This function returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`]
* instances used for multinomial sampling.
* @param {GenerationConfig} generation_config The generation config.
* @returns {LogitsProcessorList} generation_config
*/
_get_logits_warper(generation_config) {
// instantiate warpers list
const warpers = new _generation_logits_process_js__WEBPACK_IMPORTED_MODULE_7__.LogitsProcessorList();
if (generation_config.temperature !== null && generation_config.temperature !== 1.0) {
warpers.push(new _generation_logits_process_js__WEBPACK_IMPORTED_MODULE_7__.TemperatureLogitsWarper(generation_config.temperature));
}
if (generation_config.top_k !== null && generation_config.top_k !== 0) {
// TODO: add min_tokens_to_keep
warpers.push(new _generation_logits_process_js__WEBPACK_IMPORTED_MODULE_7__.TopKLogitsWarper(generation_config.top_k));
}
if (generation_config.top_p !== null && generation_config.top_p < 1.0) {
// TODO: add min_tokens_to_keep
warpers.push(new _generation_logits_process_js__WEBPACK_IMPORTED_MODULE_7__.TopPLogitsWarper(generation_config.top_p));
}
return warpers;
}
/**
* @param {GenerationConfig} generation_config
* @param {number} input_ids_seq_length The starting sequence length for the input ids.
* @returns {LogitsProcessorList}
* @private
*/
_get_logits_processor(
generation_config,
input_ids_seq_length,
// encoder_input_ids, TODO
// prefix_allowed_tokens_fn, TODO
logits_processor = null
) {
const processors = new _generation_logits_process_js__WEBPACK_IMPORTED_MODULE_7__.LogitsProcessorList();
// if (generation_config.diversity_penalty !== null && generation_config.diversity_penalty > 0.0) {
// processors.push(new HammingDiversityLogitsProcessor(
// generation_config.diversity_penalty,
// generation_config.num_beams,
// generation_config.num_beam_groups
// ));
// }
// if (generation_config.encoder_repetition_penalty !== null && generation_config.encoder_repetition_penalty !== 1.0) {
// processors.push(new EncoderRepetitionPenaltyLogitsProcessor(
// generation_config.encoder_repetition_penalty,
// encoder_input_ids
// ));
// }
if (generation_config.repetition_penalty !== null && generation_config.repetition_penalty !== 1.0) {
processors.push(new _generation_logits_process_js__WEBPACK_IMPORTED_MODULE_7__.RepetitionPenaltyLogitsProcessor(generation_config.repetition_penalty));
}
if (generation_config.no_repeat_ngram_size !== null && generation_config.no_repeat_ngram_size > 0) {
processors.push(new _generation_logits_process_js__WEBPACK_IMPORTED_MODULE_7__.NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size));
}
// if (generation_config.encoder_no_repeat_ngram_size !== null && generation_config.encoder_no_repeat_ngram_size > 0) {
// if (this.config.is_encoder_decoder) {
// processors.push(new EncoderNoRepeatNGramLogitsProcessor(
// generation_config.encoder_no_repeat_ngram_size,
// encoder_input_ids
// ));
// } else {
// throw new Error("It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture");
// }
// }
if (generation_config.bad_words_ids !== null) {
processors.push(new _generation_logits_process_js__WEBPACK_IMPORTED_MODULE_7__.NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id));
}
if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) {
processors.push(new _generation_logits_process_js__WEBPACK_IMPORTED_MODULE_7__.MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id));
}
if (generation_config.min_new_tokens !== null && generation_config.eos_token_id !== null && generation_config.min_new_tokens > 0) {
processors.push(new _generation_logits_process_js__WEBPACK_IMPORTED_MODULE_7__.MinNewTokensLengthLogitsProcessor(
input_ids_seq_length,
generation_config.min_new_tokens,
generation_config.eos_token_id
));
}
// if (prefix_allowed_tokens_fn !== null) {
// processors.push(new PrefixConstrainedLogitsProcessor(
// prefix_allowed_tokens_fn,
// generation_config.num_beams / generation_config.num_beam_groups
// ));
// }
if (generation_config.forced_bos_token_id !== null) {
processors.push(new _generation_logits_process_js__WEBPACK_IMPORTED_MODULE_7__.ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id));
}
if (generation_config.forced_eos_token_id !== null) {
processors.push(new _generation_logits_process_js__WEBPACK_IMPORTED_MODULE_7__.ForcedEOSTokenLogitsProcessor(
generation_config.max_length,
generation_config.forced_eos_token_id
));
}
// if (generation_config.remove_invalid_values === true) {
// processors.push(new InfNanRemoveLogitsProcessor());
// }
// if (generation_config.exponential_decay_length_penalty !== null) {
// processors.push(new ExponentialDecayLengthPenalty(
// generation_config.exponential_decay_length_penalty,
// generation_config.eos_token_id,
// input_ids_seq_length
// ));
// }
// if (generation_config.suppress_tokens !== null) {
// processors.push(new SuppressTokensLogitsProcessor(generation_config.suppress_tokens));
// }
if (generation_config.begin_suppress_tokens !== null) {
const begin_index = (input_ids_seq_length > 1 || generation_config.forced_bos_token_id === null)
? input_ids_seq_length
: input_ids_seq_length + 1;
processors.push(new _generation_logits_process_js__WEBPACK_IMPORTED_MODULE_7__.SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index));
}
// DEPRECATED: https://github.com/huggingface/transformers/pull/29485
// if (generation_config.forced_decoder_ids !== null) {
// processors.push(new ForceTokensLogitsProcessor(generation_config.forced_decoder_ids));
// }
// 8. prepare batched CFG externally
if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) {
processors.push(new _generation_logits_process_js__WEBPACK_IMPORTED_MODULE_7__.ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale));
}
if (logits_processor !== null) {
processors.extend(logits_processor)
}
// `LogitNormalization` should always be the last logit processor, when present
// if (generation_config.renormalize_logits === true) {
// processors.push(new LogitNormalization());
// }
return processors;
}
/**
* This function merges multiple generation configs together to form a final generation config to be used by the model for text generation.
* It first creates an empty `GenerationConfig` object, then it applies the model's own `generation_config` property to it. Finally, if a `generation_config` object was passed in the arguments, it overwrites the corresponding properties in the final config with those of the passed config object.
* @param {GenerationConfig|null} generation_config A `GenerationConfig` object containing generation parameters.
* @param {Object} kwargs Additional generation parameters to be used in place of those in the `generation_config` object.
* @returns {GenerationConfig} The final generation config object to be used by the model for text generation.
*/
_prepare_generation_config(generation_config, kwargs, cls = _generation_configuration_utils_js__WEBPACK_IMPORTED_MODULE_8__.GenerationConfig) {
// Create empty generation config (contains defaults)
// We pass `this.config` so that if `eos_token_id` or `bos_token_id` exist in the model's config, we will use them
const config = { ...this.config };
for (const key of ["decoder", "generator", "text_config"]) {
// Special case: some models have generation attributes set in the decoder.
// Use them if still unset in the generation config.
if (key in config) {
Object.assign(config, config[key]);
}
}
const gen_config = new cls(config);
// Apply model's generation config, if it exists
Object.assign(gen_config, this.generation_config ?? {});
// Next, use any generation config specified by the user
// when calling `generate`
if (generation_config) {
Object.assign(gen_config, generation_config);
}
// Finally, if any kwargs were passed, use them to overwrite
if (kwargs) {
Object.assign(gen_config, (0,_utils_core_js__WEBPACK_IMPORTED_MODULE_4__.pick)(kwargs, Object.getOwnPropertyNames(gen_config)));
}
return gen_config;
}
/**
*
* @param {GenerationConfig} generation_config
* @param {StoppingCriteriaList} [stopping_criteria=null]
*/
_get_stopping_criteria(generation_config, stopping_criteria = null) {
const criteria = new _generation_stopping_criteria_js__WEBPACK_IMPORTED_MODULE_12__.StoppingCriteriaList();
if (generation_config.max_length !== null) {
criteria.push(new _generation_stopping_criteria_js__WEBPACK_IMPORTED_MODULE_12__.MaxLengthCriteria(
generation_config.max_length,
this.config.max_position_embeddings ?? null,
));
}
// if (generation_config.max_time !== null) {
// criteria.push(new MaxTimeCriteria(generation_config.max_time));
// }
if (generation_config.eos_token_id !== null) {
criteria.push(new _generation_stopping_criteria_js__WEBPACK_IMPORTED_MODULE_12__.EosTokenCriteria(generation_config.eos_token_id));
}
if (stopping_criteria) {
criteria.extend(stopping_criteria);
}
return criteria;
}
/** | | |