Scippy

SCIP

Solving Constraint Integer Programs

bandit_ucb.c
Go to the documentation of this file.
1 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2 /* */
3 /* This file is part of the program and library */
4 /* SCIP --- Solving Constraint Integer Programs */
5 /* */
6 /* Copyright (c) 2002-2023 Zuse Institute Berlin (ZIB) */
7 /* */
8 /* Licensed under the Apache License, Version 2.0 (the "License"); */
9 /* you may not use this file except in compliance with the License. */
10 /* You may obtain a copy of the License at */
11 /* */
12 /* http://www.apache.org/licenses/LICENSE-2.0 */
13 /* */
14 /* Unless required by applicable law or agreed to in writing, software */
15 /* distributed under the License is distributed on an "AS IS" BASIS, */
16 /* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */
17 /* See the License for the specific language governing permissions and */
18 /* limitations under the License. */
19 /* */
20 /* You should have received a copy of the Apache-2.0 license */
21 /* along with SCIP; see the file LICENSE. If not visit scipopt.org. */
22 /* */
23 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
24 
25 /**@file bandit_ucb.c
26  * @ingroup OTHER_CFILES
27  * @brief methods for UCB bandit selection
28  * @author Gregor Hendel
29  */
30 
31 /*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
32 
33 #include "scip/bandit.h"
34 #include "scip/bandit_ucb.h"
35 #include "scip/pub_bandit.h"
36 #include "scip/pub_message.h"
37 #include "scip/pub_misc.h"
38 #include "scip/pub_misc_sort.h"
39 #include "scip/scip_bandit.h"
40 #include "scip/scip_mem.h"
41 #include "scip/scip_randnumgen.h"
42 
43 
44 #define BANDIT_NAME "ucb"
45 #define NUMEPS 1e-6
46 
47 /*
48  * Data structures
49  */
50 
51 /** implementation specific data of UCB bandit algorithm */
52 struct SCIP_BanditData
53 {
54  int nselections; /**< counter for the number of selections */
55  int* counter; /**< array of counters how often every action has been chosen */
56  int* startperm; /**< indices for starting permutation */
57  SCIP_Real* meanscores; /**< array of average scores for the actions */
58  SCIP_Real alpha; /**< parameter to increase confidence width */
59 };
60 
61 
62 /*
63  * Local methods
64  */
65 
66 /** data reset method */
67 static
69  BMS_BUFMEM* bufmem, /**< buffer memory */
70  SCIP_BANDIT* ucb, /**< ucb bandit algorithm */
71  SCIP_BANDITDATA* banditdata, /**< UCB bandit data structure */
72  SCIP_Real* priorities, /**< priorities for start permutation, or NULL */
73  int nactions /**< number of actions */
74  )
75 {
76  int i;
77  SCIP_RANDNUMGEN* rng;
78 
79  assert(bufmem != NULL);
80  assert(ucb != NULL);
81  assert(nactions > 0);
82 
83  /* clear counters and scores */
84  BMSclearMemoryArray(banditdata->counter, nactions);
85  BMSclearMemoryArray(banditdata->meanscores, nactions);
86  banditdata->nselections = 0;
87 
88  rng = SCIPbanditGetRandnumgen(ucb);
89  assert(rng != NULL);
90 
91  /* initialize start permutation as identity */
92  for( i = 0; i < nactions; ++i )
93  banditdata->startperm[i] = i;
94 
95  /* prepare the start permutation in decreasing order of priority */
96  if( priorities != NULL )
97  {
98  SCIP_Real* prioritycopy;
99 
100  SCIP_ALLOC( BMSduplicateBufferMemoryArray(bufmem, &prioritycopy, priorities, nactions) );
101 
102  /* randomly wiggle priorities a little bit to make them unique */
103  for( i = 0; i < nactions; ++i )
104  prioritycopy[i] += SCIPrandomGetReal(rng, -NUMEPS, NUMEPS);
105 
106  SCIPsortDownRealInt(prioritycopy, banditdata->startperm, nactions);
107 
108  BMSfreeBufferMemoryArray(bufmem, &prioritycopy);
109  }
110  else
111  {
112  /* use a random start permutation */
113  SCIPrandomPermuteIntArray(rng, banditdata->startperm, 0, nactions);
114  }
115 
116  return SCIP_OKAY;
117 }
118 
119 
120 /*
121  * Callback methods of bandit algorithm
122  */
123 
124 /** callback to free bandit specific data structures */
125 SCIP_DECL_BANDITFREE(SCIPbanditFreeUcb)
126 { /*lint --e{715}*/
127  SCIP_BANDITDATA* banditdata;
128  int nactions;
129  assert(bandit != NULL);
130 
131  banditdata = SCIPbanditGetData(bandit);
132  assert(banditdata != NULL);
133  nactions = SCIPbanditGetNActions(bandit);
134 
135  BMSfreeBlockMemoryArray(blkmem, &banditdata->counter, nactions);
136  BMSfreeBlockMemoryArray(blkmem, &banditdata->startperm, nactions);
137  BMSfreeBlockMemoryArray(blkmem, &banditdata->meanscores, nactions);
138  BMSfreeBlockMemory(blkmem, &banditdata);
139 
140  SCIPbanditSetData(bandit, NULL);
141 
142  return SCIP_OKAY;
143 }
144 
145 /** selection callback for bandit selector */
146 SCIP_DECL_BANDITSELECT(SCIPbanditSelectUcb)
147 { /*lint --e{715}*/
148  SCIP_BANDITDATA* banditdata;
149  int nactions;
150  int* counter;
151 
152  assert(bandit != NULL);
153  assert(selection != NULL);
154 
155  banditdata = SCIPbanditGetData(bandit);
156  assert(banditdata != NULL);
157  nactions = SCIPbanditGetNActions(bandit);
158 
159  counter = banditdata->counter;
160  /* select the next uninitialized action from the start permutation */
161  if( banditdata->nselections < nactions )
162  {
163  *selection = banditdata->startperm[banditdata->nselections];
164  assert(counter[*selection] == 0);
165  }
166  else
167  {
168  /* select the action with the highest upper confidence bound */
169  SCIP_Real* meanscores;
170  SCIP_Real widthfactor;
171  SCIP_Real maxucb;
172  int i;
174  meanscores = banditdata->meanscores;
175 
176  assert(rng != NULL);
177  assert(meanscores != NULL);
178 
179  /* compute the confidence width factor that is common for all actions */
180  /* cppcheck-suppress unpreciseMathCall */
181  widthfactor = banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections);
182  widthfactor = sqrt(widthfactor);
183  maxucb = -1.0;
184 
185  /* loop over the actions and determine the maximum upper confidence bound.
186  * The upper confidence bound of an action is the sum of its mean score
187  * plus a confidence term that decreases with increasing number of observations of
188  * this action.
189  */
190  for( i = 0; i < nactions; ++i )
191  {
192  SCIP_Real uppercb;
193  SCIP_Real rootcount;
194  assert(counter[i] > 0);
195 
196  /* compute the upper confidence bound for action i */
197  uppercb = meanscores[i];
198  rootcount = sqrt((SCIP_Real)counter[i]);
199  uppercb += widthfactor / rootcount;
200  assert(uppercb > 0);
201 
202  /* update maximum, breaking ties uniformly at random */
203  if( EPSGT(uppercb, maxucb, NUMEPS) || (EPSEQ(uppercb, maxucb, NUMEPS) && SCIPrandomGetReal(rng, 0.0, 1.0) >= 0.5) )
204  {
205  maxucb = uppercb;
206  *selection = i;
207  }
208  }
209  }
210 
211  assert(*selection >= 0);
212  assert(*selection < nactions);
213 
214  return SCIP_OKAY;
215 }
216 
217 /** update callback for bandit algorithm */
218 SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateUcb)
219 { /*lint --e{715}*/
220  SCIP_BANDITDATA* banditdata;
221  SCIP_Real delta;
222 
223  assert(bandit != NULL);
224 
225  banditdata = SCIPbanditGetData(bandit);
226  assert(banditdata != NULL);
227  assert(selection >= 0);
228  assert(selection < SCIPbanditGetNActions(bandit));
229 
230  /* increase the mean by the incremental formula: A_n = A_n-1 + 1/n (a_n - A_n-1) */
231  delta = score - banditdata->meanscores[selection];
232  ++banditdata->counter[selection];
233  banditdata->meanscores[selection] += delta / (SCIP_Real)banditdata->counter[selection];
234 
235  banditdata->nselections++;
236 
237  return SCIP_OKAY;
238 }
239 
240 /** reset callback for bandit algorithm */
241 SCIP_DECL_BANDITRESET(SCIPbanditResetUcb)
242 { /*lint --e{715}*/
243  SCIP_BANDITDATA* banditdata;
244  int nactions;
245 
246  assert(bufmem != NULL);
247  assert(bandit != NULL);
248 
249  banditdata = SCIPbanditGetData(bandit);
250  assert(banditdata != NULL);
251  nactions = SCIPbanditGetNActions(bandit);
252 
253  /* call the data reset for the given priorities */
254  SCIP_CALL( dataReset(bufmem, bandit, banditdata, priorities, nactions) );
255 
256  return SCIP_OKAY;
257 }
258 
259 /*
260  * bandit algorithm specific interface methods
261  */
262 
263 /** returns the upper confidence bound of a selected action */
265  SCIP_BANDIT* ucb, /**< UCB bandit algorithm */
266  int action /**< index of the queried action */
267  )
268 {
269  SCIP_Real uppercb;
270  SCIP_BANDITDATA* banditdata;
271  int nactions;
272 
273  assert(ucb != NULL);
274  banditdata = SCIPbanditGetData(ucb);
275  nactions = SCIPbanditGetNActions(ucb);
276  assert(action < nactions);
277 
278  /* since only scores between 0 and 1 are allowed, 1.0 is a sure upper confidence bound */
279  if( banditdata->nselections < nactions )
280  return 1.0;
281 
282  /* the bandit algorithm must have picked every action once */
283  assert(banditdata->counter[action] > 0);
284  uppercb = banditdata->meanscores[action];
285 
286  /* cppcheck-suppress unpreciseMathCall */
287  uppercb += sqrt(banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections) / (SCIP_Real)banditdata->counter[action]);
288 
289  return uppercb;
290 }
291 
292 /** return start permutation of the UCB bandit algorithm */
294  SCIP_BANDIT* ucb /**< UCB bandit algorithm */
295  )
296 {
297  SCIP_BANDITDATA* banditdata = SCIPbanditGetData(ucb);
298 
299  assert(banditdata != NULL);
300 
301  return banditdata->startperm;
302 }
303 
304 /** internal method to create and reset UCB bandit algorithm */
306  BMS_BLKMEM* blkmem, /**< block memory */
307  BMS_BUFMEM* bufmem, /**< buffer memory */
308  SCIP_BANDITVTABLE* vtable, /**< virtual function table for UCB bandit algorithm */
309  SCIP_BANDIT** ucb, /**< pointer to store bandit algorithm */
310  SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
311  SCIP_Real alpha, /**< parameter to increase confidence width */
312  int nactions, /**< the positive number of actions for this bandit algorithm */
313  unsigned int initseed /**< initial random seed */
314  )
315 {
316  SCIP_BANDITDATA* banditdata;
317 
318  if( alpha < 0.0 )
319  {
320  SCIPerrorMessage("UCB requires nonnegative alpha parameter, have %f\n", alpha);
321  return SCIP_INVALIDDATA;
322  }
323 
324  SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
325  assert(banditdata != NULL);
326 
327  SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->counter, nactions) );
328  SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->startperm, nactions) );
329  SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->meanscores, nactions) );
330 
331  banditdata->alpha = alpha;
332 
333  SCIP_CALL( SCIPbanditCreate(ucb, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
334 
335  return SCIP_OKAY;
336 }
337 
338 /** create and reset UCB bandit algorithm */
340  SCIP* scip, /**< SCIP data structure */
341  SCIP_BANDIT** ucb, /**< pointer to store bandit algorithm */
342  SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
343  SCIP_Real alpha, /**< parameter to increase confidence width */
344  int nactions, /**< the positive number of actions for this bandit algorithm */
345  unsigned int initseed /**< initial random number seed */
346  )
347 {
348  SCIP_BANDITVTABLE* vtable;
349 
350  vtable = SCIPfindBanditvtable(scip, BANDIT_NAME);
351  if( vtable == NULL )
352  {
353  SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
354  return SCIP_INVALIDDATA;
355  }
356 
357  SCIP_CALL( SCIPbanditCreateUcb(SCIPblkmem(scip), SCIPbuffer(scip), vtable, ucb,
358  priorities, alpha, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
359 
360  return SCIP_OKAY;
361 }
362 
363 /** include virtual function table for UCB bandit algorithms */
365  SCIP* scip /**< SCIP data structure */
366  )
367 {
368  SCIP_BANDITVTABLE* vtable;
369 
371  SCIPbanditFreeUcb, SCIPbanditSelectUcb, SCIPbanditUpdateUcb, SCIPbanditResetUcb) );
372  assert(vtable != NULL);
373 
374  return SCIP_OKAY;
375 }
SCIP_RETCODE SCIPcreateBanditUcb(SCIP *scip, SCIP_BANDIT **ucb, SCIP_Real *priorities, SCIP_Real alpha, int nactions, unsigned int initseed)
Definition: bandit_ucb.c:339
public methods for memory management
#define EPSEQ(x, y, eps)
Definition: def.h:211
internal methods for bandit algorithms
enum SCIP_Retcode SCIP_RETCODE
Definition: type_retcode.h:63
void SCIPsortDownRealInt(SCIP_Real *realarray, int *intarray, int len)
#define BMSduplicateBufferMemoryArray(mem, ptr, source, num)
Definition: memory.h:739
SCIP_BANDITDATA * SCIPbanditGetData(SCIP_BANDIT *bandit)
Definition: bandit.c:190
BMS_BUFMEM * SCIPbuffer(SCIP *scip)
Definition: scip_mem.c:72
#define SCIPerrorMessage
Definition: pub_message.h:64
internal methods for UCB bandit algorithm
SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateUcb)
Definition: bandit_ucb.c:218
BMS_BLKMEM * SCIPblkmem(SCIP *scip)
Definition: scip_mem.c:57
SCIP_BANDITVTABLE * SCIPfindBanditvtable(SCIP *scip, const char *name)
Definition: scip_bandit.c:80
#define NULL
Definition: lpi_spx1.cpp:164
SCIP_DECL_BANDITRESET(SCIPbanditResetUcb)
Definition: bandit_ucb.c:241
#define SCIP_CALL(x)
Definition: def.h:394
void SCIPbanditSetData(SCIP_BANDIT *bandit, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:200
#define NUMEPS
Definition: bandit_ucb.c:45
#define BMSfreeBlockMemory(mem, ptr)
Definition: memory.h:467
public data structures and miscellaneous methods
SCIP_RETCODE SCIPincludeBanditvtableUcb(SCIP *scip)
Definition: bandit_ucb.c:364
#define BMSallocBlockMemoryArray(mem, ptr, num)
Definition: memory.h:456
SCIP_RETCODE SCIPincludeBanditvtable(SCIP *scip, SCIP_BANDITVTABLE **banditvtable, const char *name, SCIP_DECL_BANDITFREE((*banditfree)), SCIP_DECL_BANDITSELECT((*banditselect)), SCIP_DECL_BANDITUPDATE((*banditupdate)), SCIP_DECL_BANDITRESET((*banditreset)))
Definition: scip_bandit.c:48
SCIP_DECL_BANDITSELECT(SCIPbanditSelectUcb)
Definition: bandit_ucb.c:146
#define BMSfreeBlockMemoryArray(mem, ptr, num)
Definition: memory.h:469
void SCIPrandomPermuteIntArray(SCIP_RANDNUMGEN *randnumgen, int *array, int begin, int end)
Definition: misc.c:10060
int * SCIPgetStartPermutationUcb(SCIP_BANDIT *ucb)
Definition: bandit_ucb.c:293
public methods for bandit algorithms
#define BANDIT_NAME
Definition: bandit_ucb.c:44
SCIP_Real SCIPrandomGetReal(SCIP_RANDNUMGEN *randnumgen, SCIP_Real minrandval, SCIP_Real maxrandval)
Definition: misc.c:10041
public methods for bandit algorithms
struct SCIP_BanditData SCIP_BANDITDATA
Definition: type_bandit.h:56
methods for sorting joint arrays of various types
#define EPSGT(x, y, eps)
Definition: def.h:214
public methods for random numbers
public methods for message output
SCIP_RETCODE SCIPbanditCreateUcb(BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_BANDITVTABLE *vtable, SCIP_BANDIT **ucb, SCIP_Real *priorities, SCIP_Real alpha, int nactions, unsigned int initseed)
Definition: bandit_ucb.c:305
#define SCIP_Real
Definition: def.h:186
int SCIPbanditGetNActions(SCIP_BANDIT *bandit)
Definition: bandit.c:303
static SCIP_RETCODE dataReset(BMS_BUFMEM *bufmem, SCIP_BANDIT *ucb, SCIP_BANDITDATA *banditdata, SCIP_Real *priorities, int nactions)
Definition: bandit_ucb.c:68
SCIP_RANDNUMGEN * SCIPbanditGetRandnumgen(SCIP_BANDIT *bandit)
Definition: bandit.c:293
#define BMSallocBlockMemory(mem, ptr)
Definition: memory.h:453
#define BMSclearMemoryArray(ptr, num)
Definition: memory.h:132
unsigned int SCIPinitializeRandomSeed(SCIP *scip, unsigned int initialseedvalue)
struct BMS_BlkMem BMS_BLKMEM
Definition: memory.h:439
#define SCIP_ALLOC(x)
Definition: def.h:405
SCIP_Real SCIPgetConfidenceBoundUcb(SCIP_BANDIT *ucb, int action)
Definition: bandit_ucb.c:264
#define BMSfreeBufferMemoryArray(mem, ptr)
Definition: memory.h:744
SCIP_RETCODE SCIPbanditCreate(SCIP_BANDIT **bandit, SCIP_BANDITVTABLE *banditvtable, BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_Real *priorities, int nactions, unsigned int initseed, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:42
SCIP_DECL_BANDITFREE(SCIPbanditFreeUcb)
Definition: bandit_ucb.c:125