Scippy

SCIP

Solving Constraint Integer Programs

bandit_exp3.c
Go to the documentation of this file.
1 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2 /* */
3 /* This file is part of the program and library */
4 /* SCIP --- Solving Constraint Integer Programs */
5 /* */
6 /* Copyright (C) 2002-2019 Konrad-Zuse-Zentrum */
7 /* fuer Informationstechnik Berlin */
8 /* */
9 /* SCIP is distributed under the terms of the ZIB Academic License. */
10 /* */
11 /* You should have received a copy of the ZIB Academic License */
12 /* along with SCIP; see the file COPYING. If not visit scip.zib.de. */
13 /* */
14 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
15 
16 /**@file bandit_exp3.c
17  * @brief methods for Exp.3 bandit selection
18  * @author Gregor Hendel
19  */
20 
21 /*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
22 
23 #include "scip/bandit.h"
24 #include "scip/bandit_exp3.h"
25 #include "scip/pub_bandit.h"
26 #include "scip/pub_message.h"
27 #include "scip/pub_misc.h"
28 #include "scip/scip_bandit.h"
29 #include "scip/scip_mem.h"
30 #include "scip/scip_randnumgen.h"
31 
32 #define BANDIT_NAME "exp3"
33 #define NUMTOL 1e-6
34 
35 /*
36  * Data structures
37  */
38 
39 /** implementation specific data of Exp.3 bandit algorithm */
40 struct SCIP_BanditData
41 {
42  SCIP_Real* weights; /**< exponential weight for each arm */
43  SCIP_Real weightsum; /**< the sum of all weights */
44  SCIP_Real gamma; /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
45  SCIP_Real beta; /**< gain offset between 0 and 1 at every observation */
46 };
47 
48 /*
49  * Local methods
50  */
51 
52 /*
53  * Callback methods of bandit algorithm
54  */
55 
56 /** callback to free bandit specific data structures */
57 SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3)
58 { /*lint --e{715}*/
59  SCIP_BANDITDATA* banditdata;
60  int nactions;
61  assert(bandit != NULL);
62 
63  banditdata = SCIPbanditGetData(bandit);
64  assert(banditdata != NULL);
65  nactions = SCIPbanditGetNActions(bandit);
66 
67  BMSfreeBlockMemoryArray(blkmem, &banditdata->weights, nactions);
68 
69  BMSfreeBlockMemory(blkmem, &banditdata);
70 
71  SCIPbanditSetData(bandit, NULL);
72 
73  return SCIP_OKAY;
74 }
75 
76 /** selection callback for bandit selector */
77 SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3)
78 { /*lint --e{715}*/
79  SCIP_BANDITDATA* banditdata;
80  SCIP_RANDNUMGEN* rng;
81  SCIP_Real randnr;
82  SCIP_Real psum;
83  SCIP_Real gammaoverk;
84  SCIP_Real oneminusgamma;
85  SCIP_Real* weights;
86  SCIP_Real weightsum;
87  int i;
88  int nactions;
89 
90  assert(bandit != NULL);
91  assert(selection != NULL);
92 
93  banditdata = SCIPbanditGetData(bandit);
94  assert(banditdata != NULL);
95  rng = SCIPbanditGetRandnumgen(bandit);
96  assert(rng != NULL);
97  nactions = SCIPbanditGetNActions(bandit);
98 
99  /* draw a random number between 0 and 1 */
100  randnr = SCIPrandomGetReal(rng, 0.0, 1.0);
101 
102  /* initialize some local variables to speed up probability computations */
103  oneminusgamma = 1 - banditdata->gamma;
104  gammaoverk = banditdata->gamma / (SCIP_Real)nactions;
105  weightsum = banditdata->weightsum;
106  weights = banditdata->weights;
107  psum = 0.0;
108 
109  /* loop over probability distribution until rand is reached
110  * the loop terminates without looking at the last action,
111  * which is then selected automatically if the target probability
112  * is not reached earlier
113  */
114  for( i = 0; i < nactions - 1; ++i )
115  {
116  SCIP_Real prob;
117 
118  /* compute the probability for arm i as convex kombination of a uniform distribution and a weighted distribution */
119  prob = oneminusgamma * weights[i] / weightsum + gammaoverk;
120  psum += prob;
121 
122  /* break and select element if target probability is reached */
123  if( randnr <= psum )
124  break;
125  }
126 
127  /* select element i, which is the last action in case that the break statement hasn't been reached */
128  *selection = i;
129 
130  return SCIP_OKAY;
131 }
132 
133 /** update callback for bandit algorithm */
134 SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3)
135 { /*lint --e{715}*/
136  SCIP_BANDITDATA* banditdata;
137  SCIP_Real eta;
138  SCIP_Real gainestim;
139  SCIP_Real beta;
140  SCIP_Real weightsum;
141  SCIP_Real newweightsum;
142  SCIP_Real* weights;
143  SCIP_Real oneminusgamma;
144  SCIP_Real gammaoverk;
145  int nactions;
146 
147  assert(bandit != NULL);
148 
149  banditdata = SCIPbanditGetData(bandit);
150  assert(banditdata != NULL);
151  nactions = SCIPbanditGetNActions(bandit);
152 
153  assert(selection >= 0);
154  assert(selection < nactions);
155 
156  /* the learning rate eta */
157  eta = 1.0 / (SCIP_Real)nactions;
158 
159  beta = banditdata->beta;
160  oneminusgamma = 1.0 - banditdata->gamma;
161  gammaoverk = banditdata->gamma * eta;
162  weights = banditdata->weights;
163  weightsum = banditdata->weightsum;
164  newweightsum = weightsum;
165 
166  /* if beta is zero, only the observation for the current arm needs an update */
167  if( EPSZ(beta, NUMTOL) )
168  {
169  SCIP_Real probai;
170  probai = oneminusgamma * weights[selection] / weightsum + gammaoverk;
171 
172  assert(probai > 0.0);
173 
174  gainestim = score / probai;
175  newweightsum -= weights[selection];
176  weights[selection] *= exp(eta * gainestim);
177  newweightsum += weights[selection];
178  }
179  else
180  {
181  int j;
182  newweightsum = 0.0;
183 
184  /* loop over all items and update their weights based on the influence of the beta parameter */
185  for( j = 0; j < nactions; ++j )
186  {
187  SCIP_Real probaj;
188  probaj = oneminusgamma * weights[j] / weightsum + gammaoverk;
189 
190  assert(probaj > 0.0);
191 
192  /* consider the score only for the chosen arm i, use constant beta offset otherwise */
193  if( j == selection )
194  gainestim = (score + beta) / probaj;
195  else
196  gainestim = beta / probaj;
197 
198  weights[j] *= exp(eta * gainestim);
199  newweightsum += weights[j];
200  }
201  }
202 
203  banditdata->weightsum = newweightsum;
204 
205  return SCIP_OKAY;
206 }
207 
208 /** reset callback for bandit algorithm */
209 SCIP_DECL_BANDITRESET(SCIPbanditResetExp3)
210 { /*lint --e{715}*/
211  SCIP_BANDITDATA* banditdata;
212  SCIP_Real* weights;
213  int nactions;
214  int i;
215 
216  assert(bandit != NULL);
217 
218  banditdata = SCIPbanditGetData(bandit);
219  assert(banditdata != NULL);
220  nactions = SCIPbanditGetNActions(bandit);
221  weights = banditdata->weights;
222 
223  assert(nactions > 0);
224 
225  banditdata->weightsum = (1.0 + NUMTOL) * (SCIP_Real)nactions;
226 
227  /* in case of priorities, weights are normalized to sum up to nactions */
228  if( priorities != NULL )
229  {
230  SCIP_Real normalization;
231  SCIP_Real priosum;
232  priosum = 0.0;
233 
234  /* compute sum of priorities */
235  for( i = 0; i < nactions; ++i )
236  {
237  assert(priorities[i] >= 0);
238  priosum += priorities[i];
239  }
240 
241  /* if there are positive priorities, normalize the weights */
242  if( priosum > 0.0 )
243  {
244  normalization = nactions / priosum;
245  for( i = 0; i < nactions; ++i )
246  weights[i] = (priorities[i] * normalization) + NUMTOL;
247  }
248  else
249  {
250  /* use uniform distribution in case of all priorities being 0.0 */
251  for( i = 0; i < nactions; ++i )
252  weights[i] = 1.0 + NUMTOL;
253  }
254  }
255  else
256  {
257  /* use uniform distribution in case of unspecified priorities */
258  for( i = 0; i < nactions; ++i )
259  weights[i] = 1.0 + NUMTOL;
260  }
261 
262  return SCIP_OKAY;
263 }
264 
265 
266 /*
267  * bandit algorithm specific interface methods
268  */
269 
270 /** direct bandit creation method for the core where no SCIP pointer is available */
272  BMS_BLKMEM* blkmem, /**< block memory data structure */
273  BMS_BUFMEM* bufmem, /**< buffer memory */
274  SCIP_BANDITVTABLE* vtable, /**< virtual function table for callback functions of Exp.3 */
275  SCIP_BANDIT** exp3, /**< pointer to store bandit algorithm */
276  SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
277  SCIP_Real gammaparam, /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
278  SCIP_Real beta, /**< gain offset between 0 and 1 at every observation */
279  int nactions, /**< the positive number of actions for this bandit algorithm */
280  unsigned int initseed /**< initial random seed */
281  )
282 {
283  SCIP_BANDITDATA* banditdata;
284 
285  SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
286  assert(banditdata != NULL);
287 
288  banditdata->gamma = gammaparam;
289  banditdata->beta = beta;
290  assert(gammaparam >= 0 && gammaparam <= 1);
291  assert(beta >= 0 && beta <= 1);
292 
293  SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->weights, nactions) );
294 
295  SCIP_CALL( SCIPbanditCreate(exp3, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
296 
297  return SCIP_OKAY;
298 }
299 
300 /** creates and resets an Exp.3 bandit algorithm using \p scip pointer */
302  SCIP* scip, /**< SCIP data structure */
303  SCIP_BANDIT** exp3, /**< pointer to store bandit algorithm */
304  SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
305  SCIP_Real gammaparam, /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
306  SCIP_Real beta, /**< gain offset between 0 and 1 at every observation */
307  int nactions, /**< the positive number of actions for this bandit algorithm */
308  unsigned int initseed /**< initial seed for random number generation */
309  )
310 {
311  SCIP_BANDITVTABLE* vtable;
312 
313  vtable = SCIPfindBanditvtable(scip, BANDIT_NAME);
314  if( vtable == NULL )
315  {
316  SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
317  return SCIP_INVALIDDATA;
318  }
319 
320  SCIP_CALL( SCIPbanditCreateExp3(SCIPblkmem(scip), SCIPbuffer(scip), vtable, exp3,
321  priorities, gammaparam, beta, nactions, SCIPinitializeRandomSeed(scip, (int)(initseed % INT_MAX))) );
322 
323  return SCIP_OKAY;
324 }
325 
326 /** set gamma parameter of Exp.3 bandit algorithm to increase weight of uniform distribution */
328  SCIP_BANDIT* exp3, /**< bandit algorithm */
329  SCIP_Real gammaparam /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
330  )
331 {
332  SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
333 
334  assert(gammaparam >= 0 && gammaparam <= 1);
335 
336  banditdata->gamma = gammaparam;
337 }
338 
339 /** set beta parameter of Exp.3 bandit algorithm to increase gain offset for actions that were not played */
341  SCIP_BANDIT* exp3, /**< bandit algorithm */
342  SCIP_Real beta /**< gain offset between 0 and 1 at every observation */
343  )
344 {
345  SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
346 
347  assert(beta >= 0 && beta <= 1);
348 
349  banditdata->beta = beta;
350 }
351 
352 /** returns probability to play an action */
354  SCIP_BANDIT* exp3, /**< bandit algorithm */
355  int action /**< index of the requested action */
356  )
357 {
358  SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
359 
360  assert(banditdata->weightsum > 0.0);
361  assert(SCIPbanditGetNActions(exp3) > 0);
362 
363  return (1.0 - banditdata->gamma) * banditdata->weights[action] / banditdata->weightsum + banditdata->gamma / (SCIP_Real)SCIPbanditGetNActions(exp3);
364 }
365 
366 /** include virtual function table for Exp.3 bandit algorithms */
368  SCIP* scip /**< SCIP data structure */
369  )
370 {
371  SCIP_BANDITVTABLE* vtable;
372 
374  SCIPbanditFreeExp3, SCIPbanditSelectExp3, SCIPbanditUpdateExp3, SCIPbanditResetExp3) );
375  assert(vtable != NULL);
376 
377  return SCIP_OKAY;
378 }
#define NULL
Definition: def.h:253
public methods for memory management
SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3)
Definition: bandit_exp3.c:134
void SCIPsetGammaExp3(SCIP_BANDIT *exp3, SCIP_Real gammaparam)
Definition: bandit_exp3.c:327
SCIP_RETCODE SCIPcreateBanditExp3(SCIP *scip, SCIP_BANDIT **exp3, SCIP_Real *priorities, SCIP_Real gammaparam, SCIP_Real beta, int nactions, unsigned int initseed)
Definition: bandit_exp3.c:301
internal methods for bandit algorithms
SCIP_Real SCIPrandomGetReal(SCIP_RANDNUMGEN *randnumgen, SCIP_Real minrandval, SCIP_Real maxrandval)
Definition: misc.c:9640
enum SCIP_Retcode SCIP_RETCODE
Definition: type_retcode.h:53
void SCIPsetBetaExp3(SCIP_BANDIT *exp3, SCIP_Real beta)
Definition: bandit_exp3.c:340
SCIPInterval exp(const SCIPInterval &x)
#define BANDIT_NAME
Definition: bandit_exp3.c:32
SCIP_RETCODE SCIPbanditCreateExp3(BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_BANDITVTABLE *vtable, SCIP_BANDIT **exp3, SCIP_Real *priorities, SCIP_Real gammaparam, SCIP_Real beta, int nactions, unsigned int initseed)
Definition: bandit_exp3.c:271
SCIP_BANDITDATA * SCIPbanditGetData(SCIP_BANDIT *bandit)
Definition: bandit.c:180
SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3)
Definition: bandit_exp3.c:77
unsigned int SCIPinitializeRandomSeed(SCIP *scip, unsigned int initialseedvalue)
#define SCIPerrorMessage
Definition: pub_message.h:45
BMS_BLKMEM * SCIPblkmem(SCIP *scip)
Definition: scip_mem.c:47
SCIP_BANDITVTABLE * SCIPfindBanditvtable(SCIP *scip, const char *name)
Definition: scip_bandit.c:70
SCIP_DECL_BANDITRESET(SCIPbanditResetExp3)
Definition: bandit_exp3.c:209
#define SCIP_CALL(x)
Definition: def.h:365
SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3)
Definition: bandit_exp3.c:57
SCIP_RANDNUMGEN * SCIPbanditGetRandnumgen(SCIP_BANDIT *bandit)
Definition: bandit.c:283
void SCIPbanditSetData(SCIP_BANDIT *bandit, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:190
#define BMSfreeBlockMemory(mem, ptr)
Definition: memory.h:454
SCIP_RETCODE SCIPincludeBanditvtableExp3(SCIP *scip)
Definition: bandit_exp3.c:367
public data structures and miscellaneous methods
#define BMSallocBlockMemoryArray(mem, ptr, num)
Definition: memory.h:443
#define BMSfreeBlockMemoryArray(mem, ptr, num)
Definition: memory.h:456
public methods for bandit algorithms
BMS_BUFMEM * SCIPbuffer(SCIP *scip)
Definition: scip_mem.c:62
int SCIPbanditGetNActions(SCIP_BANDIT *bandit)
Definition: bandit.c:293
SCIP_RETCODE SCIPincludeBanditvtable(SCIP *scip, SCIP_BANDITVTABLE **banditvtable, const char *name, SCIP_DECL_BANDITFREE((*banditfree)), SCIP_DECL_BANDITSELECT((*banditselect)), SCIP_DECL_BANDITUPDATE((*banditupdate)), SCIP_DECL_BANDITRESET((*banditreset)))
Definition: scip_bandit.c:38
public methods for bandit algorithms
struct SCIP_BanditData SCIP_BANDITDATA
Definition: type_bandit.h:47
SCIP_Real SCIPgetProbabilityExp3(SCIP_BANDIT *exp3, int action)
Definition: bandit_exp3.c:353
#define NUMTOL
Definition: bandit_exp3.c:33
public methods for random numbers
internal methods for Exp.3 bandit algorithm
public methods for message output
#define SCIP_Real
Definition: def.h:164
#define BMSallocBlockMemory(mem, ptr)
Definition: memory.h:441
struct BMS_BlkMem BMS_BLKMEM
Definition: memory.h:427
#define SCIP_ALLOC(x)
Definition: def.h:376
#define EPSZ(x, eps)
Definition: def.h:194
SCIP_RETCODE SCIPbanditCreate(SCIP_BANDIT **bandit, SCIP_BANDITVTABLE *banditvtable, BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_Real *priorities, int nactions, unsigned int initseed, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:32