Scippy

SCIP

Solving Constraint Integer Programs

bandit_ucb.c
Go to the documentation of this file.
1/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2/* */
3/* This file is part of the program and library */
4/* SCIP --- Solving Constraint Integer Programs */
5/* */
6/* Copyright (c) 2002-2024 Zuse Institute Berlin (ZIB) */
7/* */
8/* Licensed under the Apache License, Version 2.0 (the "License"); */
9/* you may not use this file except in compliance with the License. */
10/* You may obtain a copy of the License at */
11/* */
12/* http://www.apache.org/licenses/LICENSE-2.0 */
13/* */
14/* Unless required by applicable law or agreed to in writing, software */
15/* distributed under the License is distributed on an "AS IS" BASIS, */
16/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */
17/* See the License for the specific language governing permissions and */
18/* limitations under the License. */
19/* */
20/* You should have received a copy of the Apache-2.0 license */
21/* along with SCIP; see the file LICENSE. If not visit scipopt.org. */
22/* */
23/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
24
25/**@file bandit_ucb.c
26 * @ingroup OTHER_CFILES
27 * @brief methods for UCB bandit selection
28 * @author Gregor Hendel
29 */
30
31/*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
32
33#include "scip/bandit.h"
34#include "scip/bandit_ucb.h"
35#include "scip/pub_bandit.h"
36#include "scip/pub_message.h"
37#include "scip/pub_misc.h"
38#include "scip/pub_misc_sort.h"
39#include "scip/scip_bandit.h"
40#include "scip/scip_mem.h"
42
43
/** name under which the UCB virtual function table is registered and looked up */
#define BANDIT_NAME "ucb"

/** tolerance used to perturb priorities and to compare upper confidence bounds */
#define NUMEPS 1e-6
46
47/*
48 * Data structures
49 */
50
51/** implementation specific data of UCB bandit algorithm */
52struct SCIP_BanditData
53{
54 int nselections; /**< counter for the number of selections */
55 int* counter; /**< array of counters how often every action has been chosen */
56 int* startperm; /**< indices for starting permutation */
57 SCIP_Real* meanscores; /**< array of average scores for the actions */
58 SCIP_Real alpha; /**< parameter to increase confidence width */
59};
60
61
62/*
63 * Local methods
64 */
65
66/** data reset method */
67static
69 BMS_BUFMEM* bufmem, /**< buffer memory */
70 SCIP_BANDIT* ucb, /**< ucb bandit algorithm */
71 SCIP_BANDITDATA* banditdata, /**< UCB bandit data structure */
72 SCIP_Real* priorities, /**< priorities for start permutation, or NULL */
73 int nactions /**< number of actions */
74 )
75{
76 int i;
77 SCIP_RANDNUMGEN* rng;
78
79 assert(bufmem != NULL);
80 assert(ucb != NULL);
81 assert(nactions > 0);
82
83 /* clear counters and scores */
84 BMSclearMemoryArray(banditdata->counter, nactions);
85 BMSclearMemoryArray(banditdata->meanscores, nactions);
86 banditdata->nselections = 0;
87
88 rng = SCIPbanditGetRandnumgen(ucb);
89 assert(rng != NULL);
90
91 /* initialize start permutation as identity */
92 for( i = 0; i < nactions; ++i )
93 banditdata->startperm[i] = i;
94
95 /* prepare the start permutation in decreasing order of priority */
96 if( priorities != NULL )
97 {
98 SCIP_Real* prioritycopy;
99
100 SCIP_ALLOC( BMSduplicateBufferMemoryArray(bufmem, &prioritycopy, priorities, nactions) );
101
102 /* randomly wiggle priorities a little bit to make them unique */
103 for( i = 0; i < nactions; ++i )
104 prioritycopy[i] += SCIPrandomGetReal(rng, -NUMEPS, NUMEPS);
105
106 SCIPsortDownRealInt(prioritycopy, banditdata->startperm, nactions);
107
108 BMSfreeBufferMemoryArray(bufmem, &prioritycopy);
109 }
110 else
111 {
112 /* use a random start permutation */
113 SCIPrandomPermuteIntArray(rng, banditdata->startperm, 0, nactions);
114 }
115
116 return SCIP_OKAY;
117}
118
119
120/*
121 * Callback methods of bandit algorithm
122 */
123
124/** callback to free bandit specific data structures */
125SCIP_DECL_BANDITFREE(SCIPbanditFreeUcb)
126{ /*lint --e{715}*/
127 SCIP_BANDITDATA* banditdata;
128 int nactions;
129 assert(bandit != NULL);
130
131 banditdata = SCIPbanditGetData(bandit);
132 assert(banditdata != NULL);
133 nactions = SCIPbanditGetNActions(bandit);
134
135 BMSfreeBlockMemoryArray(blkmem, &banditdata->counter, nactions);
136 BMSfreeBlockMemoryArray(blkmem, &banditdata->startperm, nactions);
137 BMSfreeBlockMemoryArray(blkmem, &banditdata->meanscores, nactions);
138 BMSfreeBlockMemory(blkmem, &banditdata);
139
140 SCIPbanditSetData(bandit, NULL);
141
142 return SCIP_OKAY;
143}
144
145/** selection callback for bandit selector */
146SCIP_DECL_BANDITSELECT(SCIPbanditSelectUcb)
147{ /*lint --e{715}*/
148 SCIP_BANDITDATA* banditdata;
149 int nactions;
150 int* counter;
151
152 assert(bandit != NULL);
153 assert(selection != NULL);
154
155 banditdata = SCIPbanditGetData(bandit);
156 assert(banditdata != NULL);
157 nactions = SCIPbanditGetNActions(bandit);
158
159 counter = banditdata->counter;
160 /* select the next uninitialized action from the start permutation */
161 if( banditdata->nselections < nactions )
162 {
163 *selection = banditdata->startperm[banditdata->nselections];
164 assert(counter[*selection] == 0);
165 }
166 else
167 {
168 /* select the action with the highest upper confidence bound */
169 SCIP_Real* meanscores;
170 SCIP_Real widthfactor;
171 SCIP_Real maxucb;
172 int i;
174 meanscores = banditdata->meanscores;
175
176 assert(rng != NULL);
177 assert(meanscores != NULL);
178
179 /* compute the confidence width factor that is common for all actions */
180 /* cppcheck-suppress unpreciseMathCall */
181 widthfactor = banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections);
182 widthfactor = sqrt(widthfactor);
183 maxucb = -1.0;
184
185 /* loop over the actions and determine the maximum upper confidence bound.
186 * The upper confidence bound of an action is the sum of its mean score
187 * plus a confidence term that decreases with increasing number of observations of
188 * this action.
189 */
190 for( i = 0; i < nactions; ++i )
191 {
192 SCIP_Real uppercb;
193 SCIP_Real rootcount;
194 assert(counter[i] > 0);
195
196 /* compute the upper confidence bound for action i */
197 uppercb = meanscores[i];
198 rootcount = sqrt((SCIP_Real)counter[i]);
199 uppercb += widthfactor / rootcount;
200 assert(uppercb > 0);
201
202 /* update maximum, breaking ties uniformly at random */
203 if( EPSGT(uppercb, maxucb, NUMEPS) || (EPSEQ(uppercb, maxucb, NUMEPS) && SCIPrandomGetReal(rng, 0.0, 1.0) >= 0.5) )
204 {
205 maxucb = uppercb;
206 *selection = i;
207 }
208 }
209 }
210
211 assert(*selection >= 0);
212 assert(*selection < nactions);
213
214 return SCIP_OKAY;
215}
216
217/** update callback for bandit algorithm */
218SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateUcb)
219{ /*lint --e{715}*/
220 SCIP_BANDITDATA* banditdata;
221 SCIP_Real delta;
222
223 assert(bandit != NULL);
224
225 banditdata = SCIPbanditGetData(bandit);
226 assert(banditdata != NULL);
227 assert(selection >= 0);
228 assert(selection < SCIPbanditGetNActions(bandit));
229
230 /* increase the mean by the incremental formula: A_n = A_n-1 + 1/n (a_n - A_n-1) */
231 delta = score - banditdata->meanscores[selection];
232 ++banditdata->counter[selection];
233 banditdata->meanscores[selection] += delta / (SCIP_Real)banditdata->counter[selection];
234
235 banditdata->nselections++;
236
237 return SCIP_OKAY;
238}
239
240/** reset callback for bandit algorithm */
241SCIP_DECL_BANDITRESET(SCIPbanditResetUcb)
242{ /*lint --e{715}*/
243 SCIP_BANDITDATA* banditdata;
244 int nactions;
245
246 assert(bufmem != NULL);
247 assert(bandit != NULL);
248
249 banditdata = SCIPbanditGetData(bandit);
250 assert(banditdata != NULL);
251 nactions = SCIPbanditGetNActions(bandit);
252
253 /* call the data reset for the given priorities */
254 SCIP_CALL( dataReset(bufmem, bandit, banditdata, priorities, nactions) );
255
256 return SCIP_OKAY;
257}
258
259/*
260 * bandit algorithm specific interface methods
261 */
262
263/** returns the upper confidence bound of a selected action */
265 SCIP_BANDIT* ucb, /**< UCB bandit algorithm */
266 int action /**< index of the queried action */
267 )
268{
269 SCIP_Real uppercb;
270 SCIP_BANDITDATA* banditdata;
271 int nactions;
272
273 assert(ucb != NULL);
274 banditdata = SCIPbanditGetData(ucb);
275 nactions = SCIPbanditGetNActions(ucb);
276 assert(action < nactions);
277
278 /* since only scores between 0 and 1 are allowed, 1.0 is a sure upper confidence bound */
279 if( banditdata->nselections < nactions )
280 return 1.0;
281
282 /* the bandit algorithm must have picked every action once */
283 assert(banditdata->counter[action] > 0);
284 uppercb = banditdata->meanscores[action];
285
286 /* cppcheck-suppress unpreciseMathCall */
287 uppercb += sqrt(banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections) / (SCIP_Real)banditdata->counter[action]);
288
289 return uppercb;
290}
291
292/** return start permutation of the UCB bandit algorithm */
294 SCIP_BANDIT* ucb /**< UCB bandit algorithm */
295 )
296{
297 SCIP_BANDITDATA* banditdata = SCIPbanditGetData(ucb);
298
299 assert(banditdata != NULL);
300
301 return banditdata->startperm;
302}
303
304/** internal method to create and reset UCB bandit algorithm */
306 BMS_BLKMEM* blkmem, /**< block memory */
307 BMS_BUFMEM* bufmem, /**< buffer memory */
308 SCIP_BANDITVTABLE* vtable, /**< virtual function table for UCB bandit algorithm */
309 SCIP_BANDIT** ucb, /**< pointer to store bandit algorithm */
310 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
311 SCIP_Real alpha, /**< parameter to increase confidence width */
312 int nactions, /**< the positive number of actions for this bandit algorithm */
313 unsigned int initseed /**< initial random seed */
314 )
315{
316 SCIP_BANDITDATA* banditdata;
317
318 if( alpha < 0.0 )
319 {
320 SCIPerrorMessage("UCB requires nonnegative alpha parameter, have %f\n", alpha);
321 return SCIP_INVALIDDATA;
322 }
323
324 SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
325 assert(banditdata != NULL);
326
327 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->counter, nactions) );
328 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->startperm, nactions) );
329 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->meanscores, nactions) );
330
331 banditdata->alpha = alpha;
332
333 SCIP_CALL( SCIPbanditCreate(ucb, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
334
335 return SCIP_OKAY;
336}
337
338/** create and reset UCB bandit algorithm */
340 SCIP* scip, /**< SCIP data structure */
341 SCIP_BANDIT** ucb, /**< pointer to store bandit algorithm */
342 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
343 SCIP_Real alpha, /**< parameter to increase confidence width */
344 int nactions, /**< the positive number of actions for this bandit algorithm */
345 unsigned int initseed /**< initial random number seed */
346 )
347{
348 SCIP_BANDITVTABLE* vtable;
349
351 if( vtable == NULL )
352 {
353 SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
354 return SCIP_INVALIDDATA;
355 }
356
358 priorities, alpha, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
359
360 return SCIP_OKAY;
361}
362
363/** include virtual function table for UCB bandit algorithms */
365 SCIP* scip /**< SCIP data structure */
366 )
367{
368 SCIP_BANDITVTABLE* vtable;
369
371 SCIPbanditFreeUcb, SCIPbanditSelectUcb, SCIPbanditUpdateUcb, SCIPbanditResetUcb) );
372 assert(vtable != NULL);
373
374 return SCIP_OKAY;
375}
void SCIPbanditSetData(SCIP_BANDIT *bandit, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:200
SCIP_RETCODE SCIPbanditCreate(SCIP_BANDIT **bandit, SCIP_BANDITVTABLE *banditvtable, BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_Real *priorities, int nactions, unsigned int initseed, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:42
SCIP_BANDITDATA * SCIPbanditGetData(SCIP_BANDIT *bandit)
Definition: bandit.c:190
internal methods for bandit algorithms
static SCIP_RETCODE dataReset(BMS_BUFMEM *bufmem, SCIP_BANDIT *ucb, SCIP_BANDITDATA *banditdata, SCIP_Real *priorities, int nactions)
Definition: bandit_ucb.c:68
SCIP_DECL_BANDITFREE(SCIPbanditFreeUcb)
Definition: bandit_ucb.c:125
#define NUMEPS
Definition: bandit_ucb.c:45
SCIP_RETCODE SCIPincludeBanditvtableUcb(SCIP *scip)
Definition: bandit_ucb.c:364
SCIP_DECL_BANDITSELECT(SCIPbanditSelectUcb)
Definition: bandit_ucb.c:146
#define BANDIT_NAME
Definition: bandit_ucb.c:44
SCIP_RETCODE SCIPbanditCreateUcb(BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_BANDITVTABLE *vtable, SCIP_BANDIT **ucb, SCIP_Real *priorities, SCIP_Real alpha, int nactions, unsigned int initseed)
Definition: bandit_ucb.c:305
SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateUcb)
Definition: bandit_ucb.c:218
SCIP_DECL_BANDITRESET(SCIPbanditResetUcb)
Definition: bandit_ucb.c:241
internal methods for UCB bandit algorithm
#define NULL
Definition: def.h:266
#define LOG1P(x)
Definition: def.h:221
#define SCIP_ALLOC(x)
Definition: def.h:384
#define SCIP_Real
Definition: def.h:172
#define EPSEQ(x, y, eps)
Definition: def.h:197
#define EPSGT(x, y, eps)
Definition: def.h:200
#define SCIP_CALL(x)
Definition: def.h:373
void SCIPrandomPermuteIntArray(SCIP_RANDNUMGEN *randnumgen, int *array, int begin, int end)
Definition: misc.c:10152
int * SCIPgetStartPermutationUcb(SCIP_BANDIT *ucb)
Definition: bandit_ucb.c:293
int SCIPbanditGetNActions(SCIP_BANDIT *bandit)
Definition: bandit.c:303
SCIP_RANDNUMGEN * SCIPbanditGetRandnumgen(SCIP_BANDIT *bandit)
Definition: bandit.c:293
SCIP_BANDITVTABLE * SCIPfindBanditvtable(SCIP *scip, const char *name)
Definition: scip_bandit.c:80
SCIP_Real SCIPgetConfidenceBoundUcb(SCIP_BANDIT *ucb, int action)
Definition: bandit_ucb.c:264
SCIP_RETCODE SCIPincludeBanditvtable(SCIP *scip, SCIP_BANDITVTABLE **banditvtable, const char *name, SCIP_DECL_BANDITFREE((*banditfree)), SCIP_DECL_BANDITSELECT((*banditselect)), SCIP_DECL_BANDITUPDATE((*banditupdate)), SCIP_DECL_BANDITRESET((*banditreset)))
Definition: scip_bandit.c:48
SCIP_RETCODE SCIPcreateBanditUcb(SCIP *scip, SCIP_BANDIT **ucb, SCIP_Real *priorities, SCIP_Real alpha, int nactions, unsigned int initseed)
Definition: bandit_ucb.c:339
BMS_BUFMEM * SCIPbuffer(SCIP *scip)
Definition: scip_mem.c:72
SCIP_Real SCIPrandomGetReal(SCIP_RANDNUMGEN *randnumgen, SCIP_Real minrandval, SCIP_Real maxrandval)
Definition: misc.c:10133
unsigned int SCIPinitializeRandomSeed(SCIP *scip, unsigned int initialseedvalue)
void SCIPsortDownRealInt(SCIP_Real *realarray, int *intarray, int len)
#define BMSfreeBlockMemory(mem, ptr)
Definition: memory.h:465
#define BMSduplicateBufferMemoryArray(mem, ptr, source, num)
Definition: memory.h:737
#define BMSallocBlockMemory(mem, ptr)
Definition: memory.h:451
#define BMSfreeBufferMemoryArray(mem, ptr)
Definition: memory.h:742
#define BMSallocBlockMemoryArray(mem, ptr, num)
Definition: memory.h:454
#define BMSfreeBlockMemoryArray(mem, ptr, num)
Definition: memory.h:467
#define BMSclearMemoryArray(ptr, num)
Definition: memory.h:130
struct BMS_BlkMem BMS_BLKMEM
Definition: memory.h:437
BMS_BLKMEM * SCIPblkmem(SCIP *scip)
Definition: scip_mem.c:57
public methods for bandit algorithms
public methods for message output
#define SCIPerrorMessage
Definition: pub_message.h:64
public data structures and miscellaneous methods
methods for sorting joint arrays of various types
public methods for bandit algorithms
public methods for memory management
public methods for random numbers
struct SCIP_BanditData SCIP_BANDITDATA
Definition: type_bandit.h:56
@ SCIP_INVALIDDATA
Definition: type_retcode.h:52
@ SCIP_OKAY
Definition: type_retcode.h:42
enum SCIP_Retcode SCIP_RETCODE
Definition: type_retcode.h:63