Scippy

SCIP

Solving Constraint Integer Programs

bandit_exp3.c
Go to the documentation of this file.
1 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2 /* */
3 /* This file is part of the program and library */
4 /* SCIP --- Solving Constraint Integer Programs */
5 /* */
6 /* Copyright (C) 2002-2022 Konrad-Zuse-Zentrum */
7 /* fuer Informationstechnik Berlin */
8 /* */
9 /* SCIP is distributed under the terms of the ZIB Academic License. */
10 /* */
11 /* You should have received a copy of the ZIB Academic License */
12 /* along with SCIP; see the file COPYING. If not visit scipopt.org. */
13 /* */
14 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
15 
16 /**@file bandit_exp3.c
17  * @ingroup OTHER_CFILES
18  * @brief methods for Exp.3 bandit selection
19  * @author Gregor Hendel
20  */
21 
22 /*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
23 
24 #include "scip/bandit.h"
25 #include "scip/bandit_exp3.h"
26 #include "scip/pub_bandit.h"
27 #include "scip/pub_message.h"
28 #include "scip/pub_misc.h"
29 #include "scip/scip_bandit.h"
30 #include "scip/scip_mem.h"
31 #include "scip/scip_randnumgen.h"
32 
33 #define BANDIT_NAME "exp3"
34 #define NUMTOL 1e-6
35 
36 /*
37  * Data structures
38  */
39 
40 /** implementation specific data of Exp.3 bandit algorithm */
41 struct SCIP_BanditData
42 {
43  SCIP_Real* weights; /**< exponential weight for each arm */
44  SCIP_Real weightsum; /**< the sum of all weights */
45  SCIP_Real gamma; /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
46  SCIP_Real beta; /**< gain offset between 0 and 1 at every observation */
47 };
48 
49 /*
50  * Local methods
51  */
52 
53 /*
54  * Callback methods of bandit algorithm
55  */
56 
57 /** callback to free bandit specific data structures */
58 SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3)
59 { /*lint --e{715}*/
60  SCIP_BANDITDATA* banditdata;
61  int nactions;
62  assert(bandit != NULL);
63 
64  banditdata = SCIPbanditGetData(bandit);
65  assert(banditdata != NULL);
66  nactions = SCIPbanditGetNActions(bandit);
67 
68  BMSfreeBlockMemoryArray(blkmem, &banditdata->weights, nactions);
69 
70  BMSfreeBlockMemory(blkmem, &banditdata);
71 
72  SCIPbanditSetData(bandit, NULL);
73 
74  return SCIP_OKAY;
75 }
76 
77 /** selection callback for bandit selector */
78 SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3)
79 { /*lint --e{715}*/
80  SCIP_BANDITDATA* banditdata;
81  SCIP_RANDNUMGEN* rng;
82  SCIP_Real randnr;
83  SCIP_Real psum;
84  SCIP_Real gammaoverk;
85  SCIP_Real oneminusgamma;
86  SCIP_Real* weights;
87  SCIP_Real weightsum;
88  int i;
89  int nactions;
90 
91  assert(bandit != NULL);
92  assert(selection != NULL);
93 
94  banditdata = SCIPbanditGetData(bandit);
95  assert(banditdata != NULL);
96  rng = SCIPbanditGetRandnumgen(bandit);
97  assert(rng != NULL);
98  nactions = SCIPbanditGetNActions(bandit);
99 
100  /* draw a random number between 0 and 1 */
101  randnr = SCIPrandomGetReal(rng, 0.0, 1.0);
102 
103  /* initialize some local variables to speed up probability computations */
104  oneminusgamma = 1 - banditdata->gamma;
105  gammaoverk = banditdata->gamma / (SCIP_Real)nactions;
106  weightsum = banditdata->weightsum;
107  weights = banditdata->weights;
108  psum = 0.0;
109 
110  /* loop over probability distribution until rand is reached
111  * the loop terminates without looking at the last action,
112  * which is then selected automatically if the target probability
113  * is not reached earlier
114  */
115  for( i = 0; i < nactions - 1; ++i )
116  {
117  SCIP_Real prob;
118 
119  /* compute the probability for arm i as convex kombination of a uniform distribution and a weighted distribution */
120  prob = oneminusgamma * weights[i] / weightsum + gammaoverk;
121  psum += prob;
122 
123  /* break and select element if target probability is reached */
124  if( randnr <= psum )
125  break;
126  }
127 
128  /* select element i, which is the last action in case that the break statement hasn't been reached */
129  *selection = i;
130 
131  return SCIP_OKAY;
132 }
133 
134 /** update callback for bandit algorithm */
135 SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3)
136 { /*lint --e{715}*/
137  SCIP_BANDITDATA* banditdata;
138  SCIP_Real eta;
139  SCIP_Real gainestim;
140  SCIP_Real beta;
141  SCIP_Real weightsum;
142  SCIP_Real newweightsum;
143  SCIP_Real* weights;
144  SCIP_Real oneminusgamma;
145  SCIP_Real gammaoverk;
146  int nactions;
147 
148  assert(bandit != NULL);
149 
150  banditdata = SCIPbanditGetData(bandit);
151  assert(banditdata != NULL);
152  nactions = SCIPbanditGetNActions(bandit);
153 
154  assert(selection >= 0);
155  assert(selection < nactions);
156 
157  /* the learning rate eta */
158  eta = 1.0 / (SCIP_Real)nactions;
159 
160  beta = banditdata->beta;
161  oneminusgamma = 1.0 - banditdata->gamma;
162  gammaoverk = banditdata->gamma * eta;
163  weights = banditdata->weights;
164  weightsum = banditdata->weightsum;
165  newweightsum = weightsum;
166 
167  /* if beta is zero, only the observation for the current arm needs an update */
168  if( EPSZ(beta, NUMTOL) )
169  {
170  SCIP_Real probai;
171  probai = oneminusgamma * weights[selection] / weightsum + gammaoverk;
172 
173  assert(probai > 0.0);
174 
175  gainestim = score / probai;
176  newweightsum -= weights[selection];
177  weights[selection] *= exp(eta * gainestim);
178  newweightsum += weights[selection];
179  }
180  else
181  {
182  int j;
183  newweightsum = 0.0;
184 
185  /* loop over all items and update their weights based on the influence of the beta parameter */
186  for( j = 0; j < nactions; ++j )
187  {
188  SCIP_Real probaj;
189  probaj = oneminusgamma * weights[j] / weightsum + gammaoverk;
190 
191  assert(probaj > 0.0);
192 
193  /* consider the score only for the chosen arm i, use constant beta offset otherwise */
194  if( j == selection )
195  gainestim = (score + beta) / probaj;
196  else
197  gainestim = beta / probaj;
198 
199  weights[j] *= exp(eta * gainestim);
200  newweightsum += weights[j];
201  }
202  }
203 
204  banditdata->weightsum = newweightsum;
205 
206  return SCIP_OKAY;
207 }
208 
209 /** reset callback for bandit algorithm */
210 SCIP_DECL_BANDITRESET(SCIPbanditResetExp3)
211 { /*lint --e{715}*/
212  SCIP_BANDITDATA* banditdata;
213  SCIP_Real* weights;
214  int nactions;
215  int i;
216 
217  assert(bandit != NULL);
218 
219  banditdata = SCIPbanditGetData(bandit);
220  assert(banditdata != NULL);
221  nactions = SCIPbanditGetNActions(bandit);
222  weights = banditdata->weights;
223 
224  assert(nactions > 0);
225 
226  banditdata->weightsum = (1.0 + NUMTOL) * (SCIP_Real)nactions;
227 
228  /* in case of priorities, weights are normalized to sum up to nactions */
229  if( priorities != NULL )
230  {
231  SCIP_Real normalization;
232  SCIP_Real priosum;
233  priosum = 0.0;
234 
235  /* compute sum of priorities */
236  for( i = 0; i < nactions; ++i )
237  {
238  assert(priorities[i] >= 0);
239  priosum += priorities[i];
240  }
241 
242  /* if there are positive priorities, normalize the weights */
243  if( priosum > 0.0 )
244  {
245  normalization = nactions / priosum;
246  for( i = 0; i < nactions; ++i )
247  weights[i] = (priorities[i] * normalization) + NUMTOL;
248  }
249  else
250  {
251  /* use uniform distribution in case of all priorities being 0.0 */
252  for( i = 0; i < nactions; ++i )
253  weights[i] = 1.0 + NUMTOL;
254  }
255  }
256  else
257  {
258  /* use uniform distribution in case of unspecified priorities */
259  for( i = 0; i < nactions; ++i )
260  weights[i] = 1.0 + NUMTOL;
261  }
262 
263  return SCIP_OKAY;
264 }
265 
266 
267 /*
268  * bandit algorithm specific interface methods
269  */
270 
271 /** direct bandit creation method for the core where no SCIP pointer is available */
273  BMS_BLKMEM* blkmem, /**< block memory data structure */
274  BMS_BUFMEM* bufmem, /**< buffer memory */
275  SCIP_BANDITVTABLE* vtable, /**< virtual function table for callback functions of Exp.3 */
276  SCIP_BANDIT** exp3, /**< pointer to store bandit algorithm */
277  SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
278  SCIP_Real gammaparam, /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
279  SCIP_Real beta, /**< gain offset between 0 and 1 at every observation */
280  int nactions, /**< the positive number of actions for this bandit algorithm */
281  unsigned int initseed /**< initial random seed */
282  )
283 {
284  SCIP_BANDITDATA* banditdata;
285 
286  SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
287  assert(banditdata != NULL);
288 
289  banditdata->gamma = gammaparam;
290  banditdata->beta = beta;
291  assert(gammaparam >= 0 && gammaparam <= 1);
292  assert(beta >= 0 && beta <= 1);
293 
294  SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->weights, nactions) );
295 
296  SCIP_CALL( SCIPbanditCreate(exp3, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
297 
298  return SCIP_OKAY;
299 }
300 
301 /** creates and resets an Exp.3 bandit algorithm using \p scip pointer */
303  SCIP* scip, /**< SCIP data structure */
304  SCIP_BANDIT** exp3, /**< pointer to store bandit algorithm */
305  SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
306  SCIP_Real gammaparam, /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
307  SCIP_Real beta, /**< gain offset between 0 and 1 at every observation */
308  int nactions, /**< the positive number of actions for this bandit algorithm */
309  unsigned int initseed /**< initial seed for random number generation */
310  )
311 {
312  SCIP_BANDITVTABLE* vtable;
313 
314  vtable = SCIPfindBanditvtable(scip, BANDIT_NAME);
315  if( vtable == NULL )
316  {
317  SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
318  return SCIP_INVALIDDATA;
319  }
320 
321  SCIP_CALL( SCIPbanditCreateExp3(SCIPblkmem(scip), SCIPbuffer(scip), vtable, exp3,
322  priorities, gammaparam, beta, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
323 
324  return SCIP_OKAY;
325 }
326 
327 /** set gamma parameter of Exp.3 bandit algorithm to increase weight of uniform distribution */
329  SCIP_BANDIT* exp3, /**< bandit algorithm */
330  SCIP_Real gammaparam /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
331  )
332 {
333  SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
334 
335  assert(gammaparam >= 0 && gammaparam <= 1);
336 
337  banditdata->gamma = gammaparam;
338 }
339 
340 /** set beta parameter of Exp.3 bandit algorithm to increase gain offset for actions that were not played */
342  SCIP_BANDIT* exp3, /**< bandit algorithm */
343  SCIP_Real beta /**< gain offset between 0 and 1 at every observation */
344  )
345 {
346  SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
347 
348  assert(beta >= 0 && beta <= 1);
349 
350  banditdata->beta = beta;
351 }
352 
353 /** returns probability to play an action */
355  SCIP_BANDIT* exp3, /**< bandit algorithm */
356  int action /**< index of the requested action */
357  )
358 {
359  SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
360 
361  assert(banditdata->weightsum > 0.0);
362  assert(SCIPbanditGetNActions(exp3) > 0);
363 
364  return (1.0 - banditdata->gamma) * banditdata->weights[action] / banditdata->weightsum + banditdata->gamma / (SCIP_Real)SCIPbanditGetNActions(exp3);
365 }
366 
367 /** include virtual function table for Exp.3 bandit algorithms */
369  SCIP* scip /**< SCIP data structure */
370  )
371 {
372  SCIP_BANDITVTABLE* vtable;
373 
375  SCIPbanditFreeExp3, SCIPbanditSelectExp3, SCIPbanditUpdateExp3, SCIPbanditResetExp3) );
376  assert(vtable != NULL);
377 
378  return SCIP_OKAY;
379 }
SCIP_RETCODE SCIPcreateBanditExp3(SCIP *scip, SCIP_BANDIT **exp3, SCIP_Real *priorities, SCIP_Real gammaparam, SCIP_Real beta, int nactions, unsigned int initseed)
Definition: bandit_exp3.c:302
public methods for memory management
SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3)
Definition: bandit_exp3.c:135
void SCIPsetBetaExp3(SCIP_BANDIT *exp3, SCIP_Real beta)
Definition: bandit_exp3.c:341
internal methods for bandit algorithms
enum SCIP_Retcode SCIP_RETCODE
Definition: type_retcode.h:54
#define BANDIT_NAME
Definition: bandit_exp3.c:33
SCIP_RETCODE SCIPbanditCreateExp3(BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_BANDITVTABLE *vtable, SCIP_BANDIT **exp3, SCIP_Real *priorities, SCIP_Real gammaparam, SCIP_Real beta, int nactions, unsigned int initseed)
Definition: bandit_exp3.c:272
SCIP_BANDITDATA * SCIPbanditGetData(SCIP_BANDIT *bandit)
Definition: bandit.c:181
SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3)
Definition: bandit_exp3.c:78
BMS_BUFMEM * SCIPbuffer(SCIP *scip)
Definition: scip_mem.c:63
#define SCIPerrorMessage
Definition: pub_message.h:55
BMS_BLKMEM * SCIPblkmem(SCIP *scip)
Definition: scip_mem.c:48
SCIP_BANDITVTABLE * SCIPfindBanditvtable(SCIP *scip, const char *name)
Definition: scip_bandit.c:71
#define NULL
Definition: lpi_spx1.cpp:155
SCIP_DECL_BANDITRESET(SCIPbanditResetExp3)
Definition: bandit_exp3.c:210
#define SCIP_CALL(x)
Definition: def.h:384
SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3)
Definition: bandit_exp3.c:58
void SCIPbanditSetData(SCIP_BANDIT *bandit, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:191
#define BMSfreeBlockMemory(mem, ptr)
Definition: memory.h:458
SCIP_RETCODE SCIPincludeBanditvtableExp3(SCIP *scip)
Definition: bandit_exp3.c:368
public data structures and miscellaneous methods
SCIP_Real SCIPgetProbabilityExp3(SCIP_BANDIT *exp3, int action)
Definition: bandit_exp3.c:354
#define BMSallocBlockMemoryArray(mem, ptr, num)
Definition: memory.h:447
SCIP_RETCODE SCIPincludeBanditvtable(SCIP *scip, SCIP_BANDITVTABLE **banditvtable, const char *name, SCIP_DECL_BANDITFREE((*banditfree)), SCIP_DECL_BANDITSELECT((*banditselect)), SCIP_DECL_BANDITUPDATE((*banditupdate)), SCIP_DECL_BANDITRESET((*banditreset)))
Definition: scip_bandit.c:39
#define BMSfreeBlockMemoryArray(mem, ptr, num)
Definition: memory.h:460
public methods for bandit algorithms
SCIP_Real SCIPrandomGetReal(SCIP_RANDNUMGEN *randnumgen, SCIP_Real minrandval, SCIP_Real maxrandval)
Definition: misc.c:10025
public methods for bandit algorithms
struct SCIP_BanditData SCIP_BANDITDATA
Definition: type_bandit.h:47
#define NUMTOL
Definition: bandit_exp3.c:34
public methods for random numbers
internal methods for Exp.3 bandit algorithm
void SCIPsetGammaExp3(SCIP_BANDIT *exp3, SCIP_Real gammaparam)
Definition: bandit_exp3.c:328
public methods for message output
#define SCIP_Real
Definition: def.h:177
int SCIPbanditGetNActions(SCIP_BANDIT *bandit)
Definition: bandit.c:294
SCIP_RANDNUMGEN * SCIPbanditGetRandnumgen(SCIP_BANDIT *bandit)
Definition: bandit.c:284
#define BMSallocBlockMemory(mem, ptr)
Definition: memory.h:444
unsigned int SCIPinitializeRandomSeed(SCIP *scip, unsigned int initialseedvalue)
struct BMS_BlkMem BMS_BLKMEM
Definition: memory.h:430
#define SCIP_ALLOC(x)
Definition: def.h:395
#define EPSZ(x, eps)
Definition: def.h:207
SCIP_RETCODE SCIPbanditCreate(SCIP_BANDIT **bandit, SCIP_BANDITVTABLE *banditvtable, BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_Real *priorities, int nactions, unsigned int initseed, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:33