ReClOB³T: регрессия и классификация на забывчивых двоичных деревьях с градиентным бустингом

26 апреля 2023, 13:30
Stanislav Korotky
8
162

ReClObBT - это исследовательский скрипт для технологии регрессии и классификации на забывчивых двоичных деревьях с градиентным бустингом (Regression and Classification on Oblivious Binary Boosted Trees). Поясним каждый из терминов в названии, чтобы очертить суть задачи и способы решений - это уместно, поскольку существует много типов деревьев со своей спецификой.

Деревья решений являются одним из самых популярных инструментов машинного обучения, наравне с нейронным сетями. Они так называются, поскольку состоят из правил, которые удобно визуально отображать как ветвления - в каждом из них указывается условие и пороговое значение, разделяющие набор данных по оси одного из предикторов. Правила, то есть ответвления, нанизываются друг за другом (выбираются по специальным алгоритмам) до тех пор, пока в сформировавшихся подмножествах данных не окажутся достаточно "чистые" согласованные ответы относительно классов (в случае задачи классификации) или усредненных значений (в случае регрессии). Подобные конечные узлы дерева без правил называются "листьями".

Если каждое правило имеет только одно условие и порождает две ветви, такое дерево называется двоичным (или бинарным). Как раз о таких деревьях пойдет далее речь.

Схема гипотетического дерева решений торговой системы

Схема гипотетического дерева решений торговой системы

Точку деления данных на 2 набора в каждом правиле рассчитывают с помощью различных метрик, оценивающих качество результата. Например, при регрессии обычно вычисляется ошибка MSE: чем меньше она становится после деления, тем более предпочтительна предполагаемая точка. В задачах классификации используются другие метрики, но тут есть один нюанс, о котором будет сказано ниже.

Для тех, кто не очень хорошо знаком с деревьями, вкратце обобщим их положительные стороны.

  1. Не требуется подготавливать входные данные (нормировать диапазон, подгонять друг к другу размеры классов)
  2. Не требуется проверять предикторы на их важность (алгоритм построения дерева просто не выберет неважные предикторы для включения в правила)
  3. Можно обрабатывать не только числовые, но и категорийные предикторы
  4. Можно в явном виде выделять правила получаемой модели (что обычно затруднено в НС, хотя и к ним можно применить так называемое прореживание)

Последний пункт относится только к случаям, когда модель состоит из одного дерева. Однако для улучшения характеристик модели зачастую задействуют много деревьев. Это можно организовать за счет параллельного или последовательного построения множества деревьев, с последующим объединением их результатов - также параллельным или последовательным. Параллельный метод называется - bagging, а последовательный - boosting. Именно бустинг применяется в рассматриваемом скрипте.

При бустинге каждое дерево, добавляемое в модель, обучается в направлении поставленной цели (желаемого решения), исправляя ошибки в ответе предыдущего дерева. Тем самым ответ системы постепенно улучшается в направлении цели, а эти мелкие шаги и олицетворяют локальные градиенты.

Забывчивыми называют деревья, у которых на каждом уровне ветвления используется одно и то же правило для всего дерева. Иными словами, в расчет не берется тот факт, что на предыдущих уровнях уже применялись какие-то правила, которые сформировали уникальные подмножества входных данных в каждой ветви. Таким образом, если бы дерево не было забывчивым, в каждой из ветвей следовало бы рассчитать собственное отдельное правило, наиболее оптимальное только для этой ветви. Однако по принципу забывчивости, выбирается общее правило для всего уровня. Это делает забывчивые деревья слабыми "решателями". Из-за этого их не применяют по отдельности, а только в комитетах типа бустинга. Слабость является в некотором смысле и преимуществом, так как предотвращает переобучение (которое характерно для незабывчивого дерева без ограничений в росте), а также упрощает алгоритм и затраты памяти.

Раз уж зашла об этом речь, упомянем основные отрицательные стороны деревьев.

  1. Большой расход ресурсов (память и время обучения)
  2. Потенциальная опасность переобучения (всегда можно построить достаточно большое дерево, чтобы в точности "упаковать" все записи данных в листья)

Переобучение можно нивелировать благодаря градиентному бустингу за счет отбрасывания нескольких последних деревьев обученной цепочки.

Градиентный бустинг сказывается еще на одной особенности в задачах классификации. Поскольку движение по градиенту целевой функции предполагает её определение на всем диапазоне чисел, вероятности классификации, лежащие в отрезке [0, 1], не могут быть использованы напрямую. Вместо этого вероятности переводят в логарифм отношений вероятностей (двух классов или требуемого класса и вне этого класса). Тем самым получается так называемое значение logodds, от -бесконечности до +бесконечности. Однако к подобным значениям уже нельзя применять метрики качества классификации (информативности), такие как энтропия или индекс Gini. Поэтому во время построения деревьев классификации с градиентным бустингом используется метрика аналогичная MSE, но со специфическим названием - Residual Sum of Squares (RSS). То есть, алгоритм построения дерева минимизирует сумму квадратов остатков/невязок logodds (здесь остатками для проектируемой ветви выступают разницы между целевыми значениями попадающих в неё записей и усредненным значением по всей ветви).

Но достаточно теории. Более подробно ознакомиться с принципами классификации и регрессии при бустинге деревьев можно в материалах, доступных в Интернете. 

Цель данного скрипта - проверка технологии в трейдинге, непосредственно средствами MQL5, без привлечения сторонних программ.

Если тестирование покажет приемлемые результаты, можно будет собрать библиотеку для встраивания в MQL-программы.

Скрипт поддерживает 2 режима работы: классификация и регрессия. Классификация может выполняться по двум классам (бинарная) или нескольким (multi-class, multi-label). Классы могут задаваться в одной или нескольких самостоятельных колонках. В процессе работы скрипт в любом случае создаст в матрице с данными дополнительные колонки - по одной на каждый класс. Все они добавляются с правого края матрицы. В случае регрессии целевой является только одна колонка в данных, она должна быть самой правой.

Каждый класс считается независимо, то есть скрипт не применяет softmax, сумма вероятностей не сводится к 1. Эту постобработку можно выполнить по требованию уже к полученным результатам.

В зависимости от режима качество модели оценивается с помощью MSE (в случае регрессии) или трех характеристик (в случае классификации) - Accuracy, Binary Cross-Entropy по всем классам и Brier. Также для классификации в лог выводится confusion matrix (матрица перекрестных ошибок), precision (доля правильных ответов среди отнесенных к конкретному классу) и recall (полнота охвата класса в правильных ответах по этому классу).

Скрипт подразумевает 2 раздельных этапа: создание/обучение цепочки деревьев и последующее использование готовой цепочки деревьев.


Настройка

Входные параметры разделены по смыслу на несколько групп.

* CSV Import *

  • string FileName - имя CSV-файла с исходными данными; если оставить его пустым, скрипт выполняет регрессию предопределенной смеси двух синусоид и выводит их график (см. скриншот ниже); в CSV-файле первая строка предполагается с заголовками колонок;
  • string Delimiter - символ-разделитель для CSV, по умолчанию ",";
  • string DropColumns - номера или названия колонок в CSV-файле, которые нужно выбросить из рассмотрения;
  • string TargetColumns - номера или названия колонок в CSV-файле с обозначениями целевых классов, через запятую; может быть указана только одна колонка, если все классы сведены в неё; классы можно задавать числами или метками, как и категориальные предикторы; для регрессии параметр должен быть пустым (регрессия всегда выполняется по последней колонке);
  • int TargetMixCount - количество классов в колонке TargetColumns, если в ней помечены сразу все классы; по умолчанию 0, то есть предполагается, что каждый класс отмечается в собственной колонке или это бинарная классификации с кодированием 0 и 1;
  • int TargetMixBase - номер начального класса, когда их много в одной колонке; по умолчанию 0, классы нумеруются с 0;

* Common Settings *

  • int NumberOfTrees - количество деревьев для построения (при обучении) или штатной эксплуатации (позволяет отбросить часть последних деревьев);
  • int RandomSeed - инициализация случайного генератора; влияет только на генерацию демо-синусоид и последовательность начальных отладочных записей в журнале по DebugExcerpt (см. ниже);

* New Trees Generation Settings *

  • int TreeDepth - максимальная глубина каждого дерева; рекомендуется начать со значений 50-100 и оценить затраты ресурсов и качество обучения, а затем подстраивать; значение по умолчанию 0 приводит к автоматическому вычислению глубины как суммы количества предикторов и классов;
  • bool EarlyStop - автоматическая остановка обучения при ухудшении характеристик модели (для классификации должны ухудшиться все три характеристики, указанные выше);
  • bool DropTrailingTrees - при сохранении обученной цепочки отбрасывать последние деревья, в которых не было улучшения Accuracy;

* Ready-Made Trees Loading Settings *

  • string TreeFileName - название файла с сохраненной ранее обученной цепочкой деревьев; после каждого обучения скрипт автоматически создает файл с именем вида prefix-Nt-Nf-Nc(N)-name-timestamp.tre, где prefix - "cls" или "reg" в зависимости от режима, Nt - общее количество колонок, Nf - количество предикторов, Nc - количество классов, (N) - количество колонок с классами, name - название исходного CSV-файла с данными, timestamp - текущая временная метка;
  • bool SaveTestOutput - опция для сохранения CSV-файла с результирующими колонками классификации или регрессии; по умолчанию false, и оценочные результаты выводятся только в журнал;

* Classification Settings *

  • double LR - скорость обучения, по умолчанию 0.1;
  • bool StrictAccuracy - опция расчета Accuracy (принадлежности к классу) с порогом вероятности 0.5; по умолчанию - false, и точность считается исходя из распознавания класса по максимальному значению из всех классов; такой подход эквивалентен динамическому порогу;

* Regression Settings *

  • double EarlyStopPrecision - минимальное изменение MSE в задаче регрессии, когда сработает EarlyStop;

* Auxiliary *

  • int DebugLog - уровень подробности вывода в журнал; 0 - штатный минимальный лог, 3 - максимально подробный с выводом деревьев, поэтому может быть очень большим!
  • int DebugExcerpt - количество записей данных, выводимых в журнал для сверки данных; всегда выводятся первые записи, поэтому смена RandomSeed приведет к изменению набора;
  • bool InheritParentMeasure - опция включения/отключения "хитрого" способа контроля метрики при построении ветвлений; когда опция равна true (по умолчанию), метрика родительской ветви копируется в две дочерние ветви (и относительно этой величины определяется выигрыш качества при выборе точки их последующего деления); когда опция равна false, метрики двух дочерних ветвей пересчитываются с учетом только что сделанного разбиения в родительском правиле (по идее, это канонический способ, но есть нюансы - см. далее);  
  • string Quantization - размер шага квантования значений в указанных (поименованных) колонках, например, 'capital-gain=500,capital-loss=250' или '"alcohol"=0.1,"volatile acidity"=0.01' (здесь двойные кавычки нужны, потому что они есть в исходном файле с данными, см. пример вин);
  • string NumberEncodingFile - имя внешнего текстового файла с заданной числовой кодировкой категориальных колонок; по умолчанию пусто, и категориям автоматически присваиваются номера, исходя из хэшей (для меток) и статистики использования в классах;

Получающаяся глубина дерева может быть меньше параметра TreeDepth, если на каком-то уровне не найдется улучшающего ситуацию деления. Кроме того, после построения дерево проходит операцию обрезки рудиментарных листьев, которые не изменяют качество классификации (это специфика забывчивых деревьев, так как общее правило уровня в них может быть бесполезным в части ветвей).

Параметр InheritParentMeasure, равный true, можно в некотором роде интерпретировать как передачу метрики через один уровень, что видимо способно сделать дерево более оптимальным решателем в глобальном смысле. В проведенных тестах значение true дает лучшие результаты при классификации, но для регрессии лучше подходило значение false.

График демонстрации регрессии смеси синусоид

График демонстрации регрессии смеси синусоид


Применение

В процессе работы на график выводится комментарий и изменение статуса в журнал (заданной детализации), вроде следующего:

Reading CSV-file iris.csv
Categorical vocabularies: 0
150 records read (0 empty skipped), 7 columns in 00:00s
TargetColumns: '4,5,6'
4 5 6
Drops: ''
Mode: Multi-column multi-classification
Tree depth: 5
Number of trees: 50
Early stop: true
Learning rate: 0.1
Strict accuracy calculation: false
Random seed: 1
Stats by class:
50 50 50
...
Update(0):
Accuracy: 50 of 150 (33.33%), Brier score: 0.524834
Loss (BCE): 0.966780633850818,  [0]=0.966781 [1]=0.966781 [2]=0.966781, Depth:5
Update(1):
Accuracy: 143 of 150 (95.33%), Brier score: 0.424359
Loss (BCE): 0.8023493195198749,  [0]=0.788662 [1]=0.809189 [2]=0.809197, Depth:5
...
Update(48):
Accuracy: 144 of 150 (96.00%), Brier score: 0.0177732
Loss (BCE): 0.06413326967875786,  [0]=0.000771 [1]=0.094280 [2]=0.097348, Depth:5
Update(49):
Accuracy: 144 of 150 (96.00%), Brier score: 0.017721
Loss (BCE): 0.06463890519140751,  [0]=0.000658 [1]=0.095022 [2]=0.098237, Depth:5
* Best accuracy at 38-th tree (latest are discarded) *
Trees are saved in cls-7-4-3(3)-iris-1682448483.tre
* Boosting result on 39 trees:
Target vs Forecast (excerpt):
15 targets (first 3 lines) and 15 results (last 3 lines) are shown below (every line holds same class probs per element):
[[1,1,1,0,0,0,0,1,1,1,0,0,1,0,0]
 [0,0,0,0,1,0,0,0,0,0,0,1,0,0,0]
 [0,0,0,1,0,1,1,0,0,0,1,0,0,1,1]
 [0.999,0.999,0.999,0.005,0.005,0.005,0.005,0.999,0.999,0.999,0.005,0.005,0.999,0.005,0.005]
 [0.005,0.005,0.005,0.006,0.995,0.006,0.625,0.005,0.005,0.005,0.334,0.999,0.005,0.006,0.006]
 [0.005,0.005,0.005,0.999,0.03,0.999,0.778,0.005,0.005,0.005,0.92,0.005,0.005,0.999,0.999]]
Accuracy: 144 of 150 (96.00%), Brier score: 0.0185313
Confusion matrix:
[[50,0,0]
 [0,45,5]
 [0,1,49]]
Precision by class (true predictions):
[100,97.82608695652173,90.74074074074075]
Recall/selectivity by class (completeness):
[100,90,98]
Overall Loss (BCE): 0.06127315840405371, by class:  [0]=0.003772 [1]=0.089278 [2]=0.090770

Затянувшееся обучение можно досрочно прервать, остановив скрипт из меню графика. При этом уже построенные деревья будут сохранены в tre-файл.

Контролируйте расход памяти терминалом в процессе обучения на больших наборах данных - деревья растут в геометрической прогрессии.

Скрипт проверялся на нескольких тестовых примерах, в основном для классификации. К данной публикации прилагается несколько файлов с данными и соответствующие настроечные set-файлы.


Примеры

* play.csv *

Играть или не играть в гольф. Колонки-предикторы:

  • Outlook - категория
  • Temperature - число
  • Humidity - число
  • Windy - категория

Класс 0 или 1 в последней колонке:

  • Play(0) or Don't Play(1)

Настроечный файл play4.set.

Обратите внимание, что для категориальных предикторов в журнале всегда указываются назначенные им индексы (из хэшей и статистики). Эти словари хранятся в tre-файлах вместе с деревьями.

The following constants are assigned to categorical labels:
Outlook(0):0=Sunny,1=Rain,2=Overcast,
Windy(3):0=True,1=False,

* iris.csv *

Три класса цветов, заданных числовыми предикторами:

  • sepal-length
  • sepal-width
  • petal-length
  • petal-width

Принадлежность каждой записи к классу задана с помощью one-hot encoding (0 или 1, причем единица только одна в строке) в последних трех колонках:

  • Setosa
  • Versicolour
  • Virginica

Таким образом, в параметрах имеем TargetColumns=4,5,6. Можно было бы написать и так: TargetColumns=Setosa,Versicolour,Virginica.

Настроечный файл iris3.set.


* balance-scale.csv *

Набор для проверки сбалансированности механических весов, в которых регулируется расстояние и тяжесть грузиков. Колонка с тремя классами-категориями (B (баланс), R (перевес вправо), L (перевес влево)) идет в самом начале и называется:

  • Balanced

Настройка классов достигается указанием параметров TargetColumns=0, TargetMixCount=3, TargetMixBase=0.

Колонка классов будет автоматически сдвинута в конец рабочей матрицы, в которой и выполняются расчеты.

Эта особенность с единообразным размещением классов в правых колонках облегчает чтение подробных логов (когда выводятся сами записи и градиенты подстройки ошибок).

Предикторы-числа указаны в следующих колонках:

  • Left-Weight
  • Left-Distance
  • Right-Weight
  • Right-Distance

Настроечный файл balance2.set.

Особенностью данного примера является то, что категориальные классы (со значениями метками) всегда сортируются по алфавиту, о чем сообщается в логе:

...
Swapping class column to the right: 0 <-> 4
Permutation in class column 4:
    [label] [hash] [value]
[0] "B"         66       0
[1] "L"         76       1
[2] "R"         82       2
...

* happy.csv *

Результаты опроса жителей об удовлетворенности жизнью в конкретном населенном пункте в зависимости от нескольких характеристик места. Признак того, счастлив ли житель или нет, указан в первой колонке:

  • Happy - 1 или 0

Предикторы-числа, в пятибальной системе, следуют в колонках:

  • Services
  • Costs
  • Schools
  • Police
  • Streets
  • Events

Настроечный файл happy1.set.


* winequality-red.csv *

Оценка качества вина по его характеристикам.

Числовые предикторы расположены в колонках (кавычки были в csv-файле изначально и оставлены, чтобы продемонстрировать, что их следует указывать при настройке):

  • "fixed acidity"
  • "volatile acidity" *
  • "citric acid"
  • "residual sugar" *
  • "chlorides" *
  • "free sulfur dioxide"
  • "total sulfur dioxide"
  • "density" *
  • "pH"
  • "sulphates"
  • "alcohol" *

Колонки, помеченные звездочками, настраиваются для квантования с помощью параметра Quantization (кавычки тут также нужны из-за того, что они использованы в исходном файле):

"alcohol"=0.1,"volatile acidity"=0.01,"chlorides"=0.005,"density"=0.0005,"residual sugar"=0.1

Квантование позволяет существенно уменьшить расход вычислительных ресурсов и снижает вероятность переобучения.

Качество вина определяется как рейтинг в диапазоне от 3 до 8 (теоретически, шкала продолжается в меньшую сторону (2 и 1), и потенциально вверх (9), но таких классов в выборке нет):

  • "quality"

Настроечный файл winered.set.


* adult.csv / adult-t.csv *

Обучающий и тестовый наборы данных для определения уровня дохода.

Предикторами выступают:

  • age - возраст, число;
  • workclass - категория;
  • fnlwgt - код местности (не используется, отбрасывается в настройке с помощью DropColumns);
  • education - категория (кастомизируется с помощью NumberEncodingFile, см. ниже);
  • education-num - образование как количество лет;
  • marital-status - категория;
  • occupation - категория;
  • relationship - категория;
  • race - категория;
  • sex - категория;
  • capital-gain - число;
  • capital-loss - число;
  • hours-per-week - число;
  • native-country - категория;

Упомянутые настройки категории образования хранятся в файле adult-edu.txt, который указан в параметре NumberEncodingFile.

education:Preschool=1,1st-4th=2,5th-6th=5,7th-8th=7,9th=9,10th=10,11th=11,12th=12,Prof-school=13,HS-grad=14,Some-college=15,Assoc-voc=16,Assoc-acdm=17,Bachelors=18,Masters=19,Doctorate=20

Все прочие категории обрабатываются автоматически (для них генерируются хэши, а хэши переводятся в порядковые номера согласно статистике использования меток в классах). Вся эта информация сохраняется в tre-файлах сгенерированных цепочек деревьев и проверяется наличие аналогичных меток во время эксплуатации готовых деревьев.

Целевой класс находится в колонке:

  • income - категория (<=50K, >50K)

Настроечный файл adult-edu.set.

Вот несколько результатов с этими настройками (вероятно, они могут быть улучшены). На обучающем наборе удалось получить точность в 91%.

Accuracy: 29898 of 32561 (91.82%), Brier score: 0.0558902
Confusion matrix:
[[24037,683]
 [1980,5861]]
Precision by class (true predictions):
[92.38959142099397,89.56295843520782]
Recall/selectivity by class (completeness):
[97.23705501618123,74.74811886239]
* Result of classification test *
Overall Loss (BCE): 0.20834051903308287, by class:  [0]=0.227118 [1]=0.189563
Best results: BCE=0.19582[34], Accuracy=91.93[34], Brier=0.0525124[23]

На тестовом наборе показатели ожидаемо хуже - 83% (не забывайте для тестовых выборок или при штатном запуске классификации/регрессии указывать название готового tre-файла в параметре TreeFileName).

Accuracy: 13583 of 16281 (83.43%), Brier score: 0.102971
Confusion matrix:
[[11507,928]
 [1770,2076]]
Precision by class (true predictions):
[86.6686751525194,69.10785619174435]
Recall/selectivity by class (completeness):
[92.53719340570969,53.97815912636506]
* Result of classification test *
Overall Loss (BCE): 0.5052371989237212, by class:  [0]=0.530153 [1]=0.480321
Best results: BCE=0.40007[23], Accuracy=83.61[36], Brier=0.0802523[3]

Для сравнения из Интернета были взяты результаты для этого же набора данных, полученные с помощью sklearn.

CART 0.812 (0.005)
SVM 0.837 (0.005)
BAG 0.852 (0.004)
RF 0.849 (0.004)
GBM 0.863 (0.004)

Резюме: MQL-скрипт показывает приемлемое качество.

Исходные коды не публикуются - там требуется существенный рефакторинг.

Текущая версия - экспериментальная, возможно с большим количеством ошибок. Надежные результаты не гарантированы.


Файлы:
ReClObBT.zip  921 kb
Поделитесь с друзьями: