Scippy

SCIP

Solving Constraint Integer Programs

bandit_ucb.c
Go to the documentation of this file.
1 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2 /* */
3 /* This file is part of the program and library */
4 /* SCIP --- Solving Constraint Integer Programs */
5 /* */
6 /* Copyright (C) 2002-2020 Konrad-Zuse-Zentrum */
7 /* fuer Informationstechnik Berlin */
8 /* */
9 /* SCIP is distributed under the terms of the ZIB Academic License. */
10 /* */
11 /* You should have received a copy of the ZIB Academic License */
12 /* along with SCIP; see the file COPYING. If not visit scipopt.org. */
13 /* */
14 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
15 
16 /**@file bandit_ucb.c
17  * @ingroup OTHER_CFILES
18  * @brief methods for UCB bandit selection
19  * @author Gregor Hendel
20  */
21 
22 /*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
23 
24 #include "scip/bandit.h"
25 #include "scip/bandit_ucb.h"
26 #include "scip/pub_bandit.h"
27 #include "scip/pub_message.h"
28 #include "scip/pub_misc.h"
29 #include "scip/pub_misc_sort.h"
30 #include "scip/scip_bandit.h"
31 #include "scip/scip_mem.h"
32 #include "scip/scip_randnumgen.h"
33 
34 
35 #define BANDIT_NAME "ucb"
36 #define NUMEPS 1e-6
37 
38 /*
39  * Data structures
40  */
41 
42 /** implementation specific data of UCB bandit algorithm */
43 struct SCIP_BanditData
44 {
45  int nselections; /**< counter for the number of selections */
46  int* counter; /**< array of counters how often every action has been chosen */
47  int* startperm; /**< indices for starting permutation */
48  SCIP_Real* meanscores; /**< array of average scores for the actions */
49  SCIP_Real alpha; /**< parameter to increase confidence width */
50 };
51 
52 
53 /*
54  * Local methods
55  */
56 
57 /** data reset method */
58 static
60  BMS_BUFMEM* bufmem, /**< buffer memory */
61  SCIP_BANDIT* ucb, /**< ucb bandit algorithm */
62  SCIP_BANDITDATA* banditdata, /**< UCB bandit data structure */
63  SCIP_Real* priorities, /**< priorities for start permutation, or NULL */
64  int nactions /**< number of actions */
65  )
66 {
67  int i;
68  SCIP_RANDNUMGEN* rng;
69 
70  assert(bufmem != NULL);
71  assert(ucb != NULL);
72  assert(nactions > 0);
73 
74  /* clear counters and scores */
75  BMSclearMemoryArray(banditdata->counter, nactions);
76  BMSclearMemoryArray(banditdata->meanscores, nactions);
77  banditdata->nselections = 0;
78 
79  rng = SCIPbanditGetRandnumgen(ucb);
80  assert(rng != NULL);
81 
82  /* initialize start permutation as identity */
83  for( i = 0; i < nactions; ++i )
84  banditdata->startperm[i] = i;
85 
86  /* prepare the start permutation in decreasing order of priority */
87  if( priorities != NULL )
88  {
89  SCIP_Real* prioritycopy;
90 
91  SCIP_ALLOC( BMSduplicateBufferMemoryArray(bufmem, &prioritycopy, priorities, nactions) );
92 
93  /* randomly wiggle priorities a little bit to make them unique */
94  for( i = 0; i < nactions; ++i )
95  prioritycopy[i] += SCIPrandomGetReal(rng, -NUMEPS, NUMEPS);
96 
97  SCIPsortDownRealInt(prioritycopy, banditdata->startperm, nactions);
98 
99  BMSfreeBufferMemoryArray(bufmem, &prioritycopy);
100  }
101  else
102  {
103  /* use a random start permutation */
104  SCIPrandomPermuteIntArray(rng, banditdata->startperm, 0, nactions);
105  }
106 
107  return SCIP_OKAY;
108 }
109 
110 
111 /*
112  * Callback methods of bandit algorithm
113  */
114 
115 /** callback to free bandit specific data structures */
116 SCIP_DECL_BANDITFREE(SCIPbanditFreeUcb)
117 { /*lint --e{715}*/
118  SCIP_BANDITDATA* banditdata;
119  int nactions;
120  assert(bandit != NULL);
121 
122  banditdata = SCIPbanditGetData(bandit);
123  assert(banditdata != NULL);
124  nactions = SCIPbanditGetNActions(bandit);
125 
126  BMSfreeBlockMemoryArray(blkmem, &banditdata->counter, nactions);
127  BMSfreeBlockMemoryArray(blkmem, &banditdata->startperm, nactions);
128  BMSfreeBlockMemoryArray(blkmem, &banditdata->meanscores, nactions);
129  BMSfreeBlockMemory(blkmem, &banditdata);
130 
131  SCIPbanditSetData(bandit, NULL);
132 
133  return SCIP_OKAY;
134 }
135 
136 /** selection callback for bandit selector */
137 SCIP_DECL_BANDITSELECT(SCIPbanditSelectUcb)
138 { /*lint --e{715}*/
139  SCIP_BANDITDATA* banditdata;
140  int nactions;
141  int* counter;
142 
143  assert(bandit != NULL);
144  assert(selection != NULL);
145 
146  banditdata = SCIPbanditGetData(bandit);
147  assert(banditdata != NULL);
148  nactions = SCIPbanditGetNActions(bandit);
149 
150  counter = banditdata->counter;
151  /* select the next uninitialized action from the start permutation */
152  if( banditdata->nselections < nactions )
153  {
154  *selection = banditdata->startperm[banditdata->nselections];
155  assert(counter[*selection] == 0);
156  }
157  else
158  {
159  /* select the action with the highest upper confidence bound */
160  SCIP_Real* meanscores;
161  SCIP_Real widthfactor;
162  SCIP_Real maxucb;
163  int i;
165  meanscores = banditdata->meanscores;
166 
167  assert(rng != NULL);
168  assert(meanscores != NULL);
169 
170  /* compute the confidence width factor that is common for all actions */
171  /* cppcheck-suppress unpreciseMathCall */
172  widthfactor = banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections);
173  widthfactor = sqrt(widthfactor);
174  maxucb = -1.0;
175 
176  /* loop over the actions and determine the maximum upper confidence bound.
177  * The upper confidence bound of an action is the sum of its mean score
178  * plus a confidence term that decreases with increasing number of observations of
179  * this action.
180  */
181  for( i = 0; i < nactions; ++i )
182  {
183  SCIP_Real uppercb;
184  SCIP_Real rootcount;
185  assert(counter[i] > 0);
186 
187  /* compute the upper confidence bound for action i */
188  uppercb = meanscores[i];
189  rootcount = sqrt((SCIP_Real)counter[i]);
190  uppercb += widthfactor / rootcount;
191  assert(uppercb > 0);
192 
193  /* update maximum, breaking ties uniformly at random */
194  if( EPSGT(uppercb, maxucb, NUMEPS) || (EPSEQ(uppercb, maxucb, NUMEPS) && SCIPrandomGetReal(rng, 0.0, 1.0) >= 0.5) )
195  {
196  maxucb = uppercb;
197  *selection = i;
198  }
199  }
200  }
201 
202  assert(*selection >= 0);
203  assert(*selection < nactions);
204 
205  return SCIP_OKAY;
206 }
207 
208 /** update callback for bandit algorithm */
209 SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateUcb)
210 { /*lint --e{715}*/
211  SCIP_BANDITDATA* banditdata;
212  SCIP_Real delta;
213 
214  assert(bandit != NULL);
215 
216  banditdata = SCIPbanditGetData(bandit);
217  assert(banditdata != NULL);
218  assert(selection >= 0);
219  assert(selection < SCIPbanditGetNActions(bandit));
220 
221  /* increase the mean by the incremental formula: A_n = A_n-1 + 1/n (a_n - A_n-1) */
222  delta = score - banditdata->meanscores[selection];
223  ++banditdata->counter[selection];
224  banditdata->meanscores[selection] += delta / (SCIP_Real)banditdata->counter[selection];
225 
226  banditdata->nselections++;
227 
228  return SCIP_OKAY;
229 }
230 
231 /** reset callback for bandit algorithm */
232 SCIP_DECL_BANDITRESET(SCIPbanditResetUcb)
233 { /*lint --e{715}*/
234  SCIP_BANDITDATA* banditdata;
235  int nactions;
236 
237  assert(bufmem != NULL);
238  assert(bandit != NULL);
239 
240  banditdata = SCIPbanditGetData(bandit);
241  assert(banditdata != NULL);
242  nactions = SCIPbanditGetNActions(bandit);
243 
244  /* call the data reset for the given priorities */
245  SCIP_CALL( dataReset(bufmem, bandit, banditdata, priorities, nactions) );
246 
247  return SCIP_OKAY;
248 }
249 
250 /*
251  * bandit algorithm specific interface methods
252  */
253 
254 /** returns the upper confidence bound of a selected action */
256  SCIP_BANDIT* ucb, /**< UCB bandit algorithm */
257  int action /**< index of the queried action */
258  )
259 {
260  SCIP_Real uppercb;
261  SCIP_BANDITDATA* banditdata;
262  int nactions;
263 
264  assert(ucb != NULL);
265  banditdata = SCIPbanditGetData(ucb);
266  nactions = SCIPbanditGetNActions(ucb);
267  assert(action < nactions);
268 
269  /* since only scores between 0 and 1 are allowed, 1.0 is a sure upper confidence bound */
270  if( banditdata->nselections < nactions )
271  return 1.0;
272 
273  /* the bandit algorithm must have picked every action once */
274  assert(banditdata->counter[action] > 0);
275  uppercb = banditdata->meanscores[action];
276 
277  /* cppcheck-suppress unpreciseMathCall */
278  uppercb += sqrt(banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections) / (SCIP_Real)banditdata->counter[action]);
279 
280  return uppercb;
281 }
282 
283 /** return start permutation of the UCB bandit algorithm */
285  SCIP_BANDIT* ucb /**< UCB bandit algorithm */
286  )
287 {
288  SCIP_BANDITDATA* banditdata = SCIPbanditGetData(ucb);
289 
290  assert(banditdata != NULL);
291 
292  return banditdata->startperm;
293 }
294 
295 /** internal method to create and reset UCB bandit algorithm */
297  BMS_BLKMEM* blkmem, /**< block memory */
298  BMS_BUFMEM* bufmem, /**< buffer memory */
299  SCIP_BANDITVTABLE* vtable, /**< virtual function table for UCB bandit algorithm */
300  SCIP_BANDIT** ucb, /**< pointer to store bandit algorithm */
301  SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
302  SCIP_Real alpha, /**< parameter to increase confidence width */
303  int nactions, /**< the positive number of actions for this bandit algorithm */
304  unsigned int initseed /**< initial random seed */
305  )
306 {
307  SCIP_BANDITDATA* banditdata;
308 
309  if( alpha < 0.0 )
310  {
311  SCIPerrorMessage("UCB requires nonnegative alpha parameter, have %f\n", alpha);
312  return SCIP_INVALIDDATA;
313  }
314 
315  SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
316  assert(banditdata != NULL);
317 
318  SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->counter, nactions) );
319  SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->startperm, nactions) );
320  SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->meanscores, nactions) );
321 
322  banditdata->alpha = alpha;
323 
324  SCIP_CALL( SCIPbanditCreate(ucb, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
325 
326  return SCIP_OKAY;
327 }
328 
329 /** create and reset UCB bandit algorithm */
331  SCIP* scip, /**< SCIP data structure */
332  SCIP_BANDIT** ucb, /**< pointer to store bandit algorithm */
333  SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
334  SCIP_Real alpha, /**< parameter to increase confidence width */
335  int nactions, /**< the positive number of actions for this bandit algorithm */
336  unsigned int initseed /**< initial random number seed */
337  )
338 {
339  SCIP_BANDITVTABLE* vtable;
340 
341  vtable = SCIPfindBanditvtable(scip, BANDIT_NAME);
342  if( vtable == NULL )
343  {
344  SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
345  return SCIP_INVALIDDATA;
346  }
347 
348  SCIP_CALL( SCIPbanditCreateUcb(SCIPblkmem(scip), SCIPbuffer(scip), vtable, ucb,
349  priorities, alpha, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
350 
351  return SCIP_OKAY;
352 }
353 
354 /** include virtual function table for UCB bandit algorithms */
356  SCIP* scip /**< SCIP data structure */
357  )
358 {
359  SCIP_BANDITVTABLE* vtable;
360 
362  SCIPbanditFreeUcb, SCIPbanditSelectUcb, SCIPbanditUpdateUcb, SCIPbanditResetUcb) );
363  assert(vtable != NULL);
364 
365  return SCIP_OKAY;
366 }
public methods for memory management
#define EPSEQ(x, y, eps)
Definition: def.h:188
internal methods for bandit algorithms
SCIP_Real SCIPrandomGetReal(SCIP_RANDNUMGEN *randnumgen, SCIP_Real minrandval, SCIP_Real maxrandval)
Definition: misc.c:9967
enum SCIP_Retcode SCIP_RETCODE
Definition: type_retcode.h:54
#define BMSduplicateBufferMemoryArray(mem, ptr, source, num)
Definition: memory.h:718
SCIP_BANDITDATA * SCIPbanditGetData(SCIP_BANDIT *bandit)
Definition: bandit.c:181
int * SCIPgetStartPermutationUcb(SCIP_BANDIT *ucb)
Definition: bandit_ucb.c:284
unsigned int SCIPinitializeRandomSeed(SCIP *scip, unsigned int initialseedvalue)
SCIP_EXPORT void SCIPsortDownRealInt(SCIP_Real *realarray, int *intarray, int len)
#define SCIPerrorMessage
Definition: pub_message.h:55
internal methods for UCB bandit algorithm
SCIPInterval sqrt(const SCIPInterval &x)
SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateUcb)
Definition: bandit_ucb.c:209
BMS_BLKMEM * SCIPblkmem(SCIP *scip)
Definition: scip_mem.c:48
SCIP_BANDITVTABLE * SCIPfindBanditvtable(SCIP *scip, const char *name)
Definition: scip_bandit.c:71
#define NULL
Definition: lpi_spx1.cpp:155
SCIP_DECL_BANDITRESET(SCIPbanditResetUcb)
Definition: bandit_ucb.c:232
#define SCIP_CALL(x)
Definition: def.h:364
SCIP_RANDNUMGEN * SCIPbanditGetRandnumgen(SCIP_BANDIT *bandit)
Definition: bandit.c:284
void SCIPbanditSetData(SCIP_BANDIT *bandit, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:191
#define NUMEPS
Definition: bandit_ucb.c:36
#define BMSfreeBlockMemory(mem, ptr)
Definition: memory.h:456
public data structures and miscellaneous methods
SCIP_RETCODE SCIPincludeBanditvtableUcb(SCIP *scip)
Definition: bandit_ucb.c:355
#define BMSallocBlockMemoryArray(mem, ptr, num)
Definition: memory.h:445
SCIP_RETCODE SCIPcreateBanditUcb(SCIP *scip, SCIP_BANDIT **ucb, SCIP_Real *priorities, SCIP_Real alpha, int nactions, unsigned int initseed)
Definition: bandit_ucb.c:330
SCIP_DECL_BANDITSELECT(SCIPbanditSelectUcb)
Definition: bandit_ucb.c:137
#define BMSfreeBlockMemoryArray(mem, ptr, num)
Definition: memory.h:458
SCIP_Real SCIPgetConfidenceBoundUcb(SCIP_BANDIT *ucb, int action)
Definition: bandit_ucb.c:255
public methods for bandit algorithms
void SCIPrandomPermuteIntArray(SCIP_RANDNUMGEN *randnumgen, int *array, int begin, int end)
Definition: misc.c:9986
BMS_BUFMEM * SCIPbuffer(SCIP *scip)
Definition: scip_mem.c:63
int SCIPbanditGetNActions(SCIP_BANDIT *bandit)
Definition: bandit.c:294
SCIP_RETCODE SCIPincludeBanditvtable(SCIP *scip, SCIP_BANDITVTABLE **banditvtable, const char *name, SCIP_DECL_BANDITFREE((*banditfree)), SCIP_DECL_BANDITSELECT((*banditselect)), SCIP_DECL_BANDITUPDATE((*banditupdate)), SCIP_DECL_BANDITRESET((*banditreset)))
Definition: scip_bandit.c:39
#define BANDIT_NAME
Definition: bandit_ucb.c:35
public methods for bandit algorithms
struct SCIP_BanditData SCIP_BANDITDATA
Definition: type_bandit.h:47
methods for sorting joint arrays of various types
#define EPSGT(x, y, eps)
Definition: def.h:191
public methods for random numbers
public methods for message output
SCIP_RETCODE SCIPbanditCreateUcb(BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_BANDITVTABLE *vtable, SCIP_BANDIT **ucb, SCIP_Real *priorities, SCIP_Real alpha, int nactions, unsigned int initseed)
Definition: bandit_ucb.c:296
#define SCIP_Real
Definition: def.h:163
static SCIP_RETCODE dataReset(BMS_BUFMEM *bufmem, SCIP_BANDIT *ucb, SCIP_BANDITDATA *banditdata, SCIP_Real *priorities, int nactions)
Definition: bandit_ucb.c:59
#define BMSallocBlockMemory(mem, ptr)
Definition: memory.h:443
#define BMSclearMemoryArray(ptr, num)
Definition: memory.h:122
struct BMS_BlkMem BMS_BLKMEM
Definition: memory.h:429
#define SCIP_ALLOC(x)
Definition: def.h:375
#define BMSfreeBufferMemoryArray(mem, ptr)
Definition: memory.h:723
SCIP_RETCODE SCIPbanditCreate(SCIP_BANDIT **bandit, SCIP_BANDITVTABLE *banditvtable, BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_Real *priorities, int nactions, unsigned int initseed, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:33
SCIP_DECL_BANDITFREE(SCIPbanditFreeUcb)
Definition: bandit_ucb.c:116