// STRAPS_LOCALHOST/lib/pose/XGBoostPredictor.ts

// 164 lines
// 6.4 KiB
// TypeScript

import modelData from '@/public/models/xgb_activity_model.json';
/**
 * Recursive node shape: optional split fields, an optional leaf value and
 * optional nested children.
 *
 * Nothing else in the visible part of this file references this type.
 * NOTE(review): the field names resemble XGBoost's nested per-node dump
 * format — confirm before relying on it, or remove it if it is dead code.
 */
type TreeNode = {
split_indices?: number;
split_conditions?: number;
yes?: number;
no?: number;
missing?: number;
split_type?: number;
leaf?: number;
// Nested child nodes; every field above is optional, so a node may carry any subset of them.
children?: TreeNode[];
};
/**
 * The raw JSON structure from XGBoost dump.
 *
 * Each tree is stored as parallel arrays indexed by node ID. The predictor in
 * this file starts every traversal at node 0.
 */
interface XGBModelDump {
learner: {
gradient_booster: {
model: {
trees: Array<{
// Per-node value; the predictor reads it only for leaf nodes, as that leaf's output.
base_weights: number[];
// Per-node flag; a value of 1 sends a missing/NaN feature to the left child.
default_left: number[];
// Not read by the code visible in this file.
id: number;
// Per-node index of the left child; a node whose left and right entries are both -1 is treated as a leaf.
left_children: number[];
// Not read by the code visible in this file.
loss_changes: number[];
// Not read by the code visible in this file.
parents: number[];
// Per-node index of the right child (-1 together with left_children marks a leaf).
right_children: number[];
// Per-node threshold; a feature value strictly below it goes to the left child.
split_conditions: number[];
// Per-node index into the feature vector passed to predict().
split_indices: number[];
// Not read by the code visible in this file.
split_type: number[];
// Not read by the code visible in this file.
sum_hessian: number[];
// Numeric metadata encoded as strings; not read by the code visible in this file.
tree_param: {
num_deleted: string;
num_feature: string;
num_nodes: string;
size_leaf_vector: string;
};
}>;
gbtree_model_param: {
// Not read by the code visible in this file.
num_parallel_tree: string;
// Integer encoded as a string; the predictor parses it to know how many trees to evaluate.
num_trees: string;
};
};
};
learner_model_param: {
base_score: string; // e.g. "[0.2, 0.3, 0.5]" or single float
// Integer encoded as a string; the predictor parses it as the number of output classes.
num_class: string;
}
};
}
/** Parallel-array representation of one tree inside the model dump. */
type XGBTreeArrays = XGBModelDump['learner']['gradient_booster']['model']['trees'][number];

/**
 * Evaluates a multi-class XGBoost model (JSON dump) in plain TypeScript.
 *
 * Every tree is walked from its root (node 0) to a leaf; the leaf value is
 * added to the margin of the class the tree belongs to, and the per-class
 * margins are turned into probabilities with a softmax.
 *
 * Limitations visible in this implementation:
 * - every split is handled as a numerical `feature < threshold` test;
 *   `split_type` is not consulted;
 * - trees are assigned to classes round-robin (`treeIndex % numClass`).
 */
export class XGBoostPredictor {
  private readonly model: XGBModelDump;
  private readonly numTrees: number;
  private readonly numClass: number;
  private readonly baseScores: number[];

  /**
   * @param model Model dump to evaluate. Defaults to the bundled
   *   `xgb_activity_model.json`, so `new XGBoostPredictor()` keeps working.
   * @throws Error if `num_class` is not a positive integer (such a model
   *   cannot produce a per-class probability vector).
   */
  constructor(model: XGBModelDump = modelData as unknown as XGBModelDump) {
    this.model = model;
    const booster = model.learner.gradient_booster.model;
    const params = model.learner.learner_model_param;

    const numClass = Number.parseInt(params.num_class, 10);
    if (!Number.isInteger(numClass) || numClass < 1) {
      throw new Error(
        `XGBoostPredictor: unsupported num_class "${params.num_class}" (expected a positive integer)`,
      );
    }
    this.numClass = numClass;

    // Never iterate past the trees that are actually present, even if the
    // declared count is larger or unparsable.
    const declaredTrees = Number.parseInt(booster.gbtree_model_param.num_trees, 10);
    this.numTrees =
      Number.isInteger(declaredTrees) && declaredTrees >= 0
        ? Math.min(declaredTrees, booster.trees.length)
        : booster.trees.length;

    this.baseScores = XGBoostPredictor.parseBaseScores(params.base_score, numClass);
  }

  /**
   * Parses `base_score`, which is either a JSON array string or a single
   * float string, into one value per class. Falls back to 0.5 per class
   * (and logs) when the value cannot be interpreted.
   */
  private static parseBaseScores(raw: unknown, numClass: number): number[] {
    const fallback = (): number[] => new Array<number>(numClass).fill(0.5);
    const text = (raw === undefined || raw === null ? '' : String(raw)).trim();

    if (text.startsWith('[')) {
      let parsed: unknown;
      try {
        parsed = JSON.parse(text);
      } catch (e: unknown) {
        console.error("Error parsing base_score", e);
        return fallback();
      }
      if (
        Array.isArray(parsed) &&
        parsed.every((v) => typeof v === 'number' && Number.isFinite(v))
      ) {
        const values = parsed as number[];
        if (values.length === numClass) {
          return values;
        }
        if (values.length === 1) {
          // A single-element array is broadcast to every class.
          return new Array<number>(numClass).fill(values[0]);
        }
      }
      console.error("Error parsing base_score", text);
      return fallback();
    }

    const scalar = Number.parseFloat(text);
    if (Number.isNaN(scalar)) {
      console.error("Error parsing base_score", text);
      return fallback();
    }
    return new Array<number>(numClass).fill(scalar);
  }

  /**
   * Computes class probabilities for one feature vector.
   *
   * The parsed `base_score` is not added to the margins: an offset shared by
   * all classes cancels out under softmax.
   * NOTE(review): if `base_score` holds different values per class, leaving it
   * out changes the result — confirm against the training library's output.
   *
   * @param features Feature values, indexed by each node's `split_indices`
   *   entry. Absent, `null` or `NaN` entries follow the node's default direction.
   * @returns One probability per class (length `num_class`, summing to 1).
   * @throws Error if a tree has invalid child links or never reaches a leaf.
   */
  public predict(features: number[]): number[] {
    const margins = new Array<number>(this.numClass).fill(0);
    const trees = this.model.learner.gradient_booster.model.trees;

    for (let treeIdx = 0; treeIdx < this.numTrees; treeIdx++) {
      // Trees are interleaved for classes 0, 1, 2, 0, 1, 2...
      margins[treeIdx % this.numClass] += this.evaluateTree(trees[treeIdx], features, treeIdx);
    }

    return this.softmax(margins);
  }

  /** Walks one tree from the root to a leaf and returns that leaf's weight. */
  private evaluateTree(tree: XGBTreeArrays, features: number[], treeIdx: number): number {
    const nodeCount = tree.left_children.length;
    let nodeId = 0;

    // A valid root-to-leaf path visits each node at most once, so nodeCount
    // steps is a hard upper bound; exceeding it means the links are cyclic.
    for (let step = 0; step < nodeCount; step++) {
      const left = tree.left_children[nodeId];
      const right = tree.right_children[nodeId];

      if (left === -1 && right === -1) {
        return tree.base_weights[nodeId];
      }

      const featureVal = features[tree.split_indices[nodeId]];
      let goLeft: boolean;
      if (featureVal === undefined || featureVal === null || Number.isNaN(featureVal)) {
        goLeft = tree.default_left[nodeId] === 1;
      } else {
        // XGBoost holds feature values and thresholds as 32-bit floats;
        // comparing in float32 keeps borderline samples on the same side as
        // the native library.
        goLeft = Math.fround(featureVal) < Math.fround(tree.split_conditions[nodeId]);
      }

      const next = goLeft ? left : right;
      if (!Number.isInteger(next) || next < 0 || next >= nodeCount) {
        throw new Error(
          `XGBoostPredictor: tree ${treeIdx} node ${nodeId} has invalid child index ${String(next)}`,
        );
      }
      nodeId = next;
    }

    throw new Error(`XGBoostPredictor: tree ${treeIdx} did not reach a leaf`);
  }

  /** Softmax, shifted by the largest logit so `Math.exp` cannot overflow. */
  private softmax(logits: number[]): number[] {
    const maxLogit = Math.max(...logits);
    const scores = logits.map((l) => Math.exp(l - maxLogit));
    const sumScores = scores.reduce((a, b) => a + b, 0);
    return scores.map((s) => s / sumScores);
  }
}