Scippy

SCIP

Solving Constraint Integer Programs

bandit_exp3.c
Go to the documentation of this file.
1/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2/* */
3/* This file is part of the program and library */
4/* SCIP --- Solving Constraint Integer Programs */
5/* */
6/* Copyright (c) 2002-2024 Zuse Institute Berlin (ZIB) */
7/* */
8/* Licensed under the Apache License, Version 2.0 (the "License"); */
9/* you may not use this file except in compliance with the License. */
10/* You may obtain a copy of the License at */
11/* */
12/* http://www.apache.org/licenses/LICENSE-2.0 */
13/* */
14/* Unless required by applicable law or agreed to in writing, software */
15/* distributed under the License is distributed on an "AS IS" BASIS, */
16/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */
17/* See the License for the specific language governing permissions and */
18/* limitations under the License. */
19/* */
20/* You should have received a copy of the Apache-2.0 license */
21/* along with SCIP; see the file LICENSE. If not visit scipopt.org. */
22/* */
23/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
24
25/**@file bandit_exp3.c
26 * @ingroup OTHER_CFILES
27 * @brief methods for Exp.3 bandit selection
28 * @author Gregor Hendel
29 */
30
31/*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
32
33#include "scip/bandit.h"
34#include "scip/bandit_exp3.h"
35#include "scip/pub_bandit.h"
36#include "scip/pub_message.h"
37#include "scip/pub_misc.h"
38#include "scip/scip_bandit.h"
39#include "scip/scip_mem.h"
41
42#define BANDIT_NAME "exp3"
43#define NUMTOL 1e-6
44
45/*
46 * Data structures
47 */
48
49/** implementation specific data of Exp.3 bandit algorithm */
50struct SCIP_BanditData
51{
52 SCIP_Real* weights; /**< exponential weight for each arm */
53 SCIP_Real weightsum; /**< the sum of all weights */
54 SCIP_Real gamma; /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
55 SCIP_Real beta; /**< gain offset between 0 and 1 at every observation */
56};
57
58/*
59 * Local methods
60 */
61
62/*
63 * Callback methods of bandit algorithm
64 */
65
66/** callback to free bandit specific data structures */
67SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3)
68{ /*lint --e{715}*/
69 SCIP_BANDITDATA* banditdata;
70 int nactions;
71 assert(bandit != NULL);
72
73 banditdata = SCIPbanditGetData(bandit);
74 assert(banditdata != NULL);
75 nactions = SCIPbanditGetNActions(bandit);
76
77 BMSfreeBlockMemoryArray(blkmem, &banditdata->weights, nactions);
78
79 BMSfreeBlockMemory(blkmem, &banditdata);
80
81 SCIPbanditSetData(bandit, NULL);
82
83 return SCIP_OKAY;
84}
85
86/** selection callback for bandit selector */
87SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3)
88{ /*lint --e{715}*/
89 SCIP_BANDITDATA* banditdata;
90 SCIP_RANDNUMGEN* rng;
91 SCIP_Real randnr;
92 SCIP_Real psum;
93 SCIP_Real gammaoverk;
94 SCIP_Real oneminusgamma;
95 SCIP_Real* weights;
96 SCIP_Real weightsum;
97 int i;
98 int nactions;
99
100 assert(bandit != NULL);
101 assert(selection != NULL);
102
103 banditdata = SCIPbanditGetData(bandit);
104 assert(banditdata != NULL);
105 rng = SCIPbanditGetRandnumgen(bandit);
106 assert(rng != NULL);
107 nactions = SCIPbanditGetNActions(bandit);
108
109 /* draw a random number between 0 and 1 */
110 randnr = SCIPrandomGetReal(rng, 0.0, 1.0);
111
112 /* initialize some local variables to speed up probability computations */
113 oneminusgamma = 1 - banditdata->gamma;
114 gammaoverk = banditdata->gamma / (SCIP_Real)nactions;
115 weightsum = banditdata->weightsum;
116 weights = banditdata->weights;
117 psum = 0.0;
118
119 /* loop over probability distribution until rand is reached
120 * the loop terminates without looking at the last action,
121 * which is then selected automatically if the target probability
122 * is not reached earlier
123 */
124 for( i = 0; i < nactions - 1; ++i )
125 {
126 SCIP_Real prob;
127
128 /* compute the probability for arm i as convex kombination of a uniform distribution and a weighted distribution */
129 prob = oneminusgamma * weights[i] / weightsum + gammaoverk;
130 psum += prob;
131
132 /* break and select element if target probability is reached */
133 if( randnr <= psum )
134 break;
135 }
136
137 /* select element i, which is the last action in case that the break statement hasn't been reached */
138 *selection = i;
139
140 return SCIP_OKAY;
141}
142
143/** update callback for bandit algorithm */
144SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3)
145{ /*lint --e{715}*/
146 SCIP_BANDITDATA* banditdata;
147 SCIP_Real eta;
148 SCIP_Real gainestim;
149 SCIP_Real beta;
150 SCIP_Real weightsum;
151 SCIP_Real newweightsum;
152 SCIP_Real* weights;
153 SCIP_Real oneminusgamma;
154 SCIP_Real gammaoverk;
155 int nactions;
156
157 assert(bandit != NULL);
158
159 banditdata = SCIPbanditGetData(bandit);
160 assert(banditdata != NULL);
161 nactions = SCIPbanditGetNActions(bandit);
162
163 assert(selection >= 0);
164 assert(selection < nactions);
165
166 /* the learning rate eta */
167 eta = 1.0 / (SCIP_Real)nactions;
168
169 beta = banditdata->beta;
170 oneminusgamma = 1.0 - banditdata->gamma;
171 gammaoverk = banditdata->gamma * eta;
172 weights = banditdata->weights;
173 weightsum = banditdata->weightsum;
174 newweightsum = weightsum;
175
176 /* if beta is zero, only the observation for the current arm needs an update */
177 if( EPSZ(beta, NUMTOL) )
178 {
179 SCIP_Real probai;
180 probai = oneminusgamma * weights[selection] / weightsum + gammaoverk;
181
182 assert(probai > 0.0);
183
184 gainestim = score / probai;
185 newweightsum -= weights[selection];
186 weights[selection] *= exp(eta * gainestim);
187 newweightsum += weights[selection];
188 }
189 else
190 {
191 int j;
192 newweightsum = 0.0;
193
194 /* loop over all items and update their weights based on the influence of the beta parameter */
195 for( j = 0; j < nactions; ++j )
196 {
197 SCIP_Real probaj;
198 probaj = oneminusgamma * weights[j] / weightsum + gammaoverk;
199
200 assert(probaj > 0.0);
201
202 /* consider the score only for the chosen arm i, use constant beta offset otherwise */
203 if( j == selection )
204 gainestim = (score + beta) / probaj;
205 else
206 gainestim = beta / probaj;
207
208 weights[j] *= exp(eta * gainestim);
209 newweightsum += weights[j];
210 }
211 }
212
213 banditdata->weightsum = newweightsum;
214
215 return SCIP_OKAY;
216}
217
218/** reset callback for bandit algorithm */
219SCIP_DECL_BANDITRESET(SCIPbanditResetExp3)
220{ /*lint --e{715}*/
221 SCIP_BANDITDATA* banditdata;
222 SCIP_Real* weights;
223 int nactions;
224 int i;
225
226 assert(bandit != NULL);
227
228 banditdata = SCIPbanditGetData(bandit);
229 assert(banditdata != NULL);
230 nactions = SCIPbanditGetNActions(bandit);
231 weights = banditdata->weights;
232
233 assert(nactions > 0);
234
235 banditdata->weightsum = (1.0 + NUMTOL) * (SCIP_Real)nactions;
236
237 /* in case of priorities, weights are normalized to sum up to nactions */
238 if( priorities != NULL )
239 {
240 SCIP_Real normalization;
241 SCIP_Real priosum;
242 priosum = 0.0;
243
244 /* compute sum of priorities */
245 for( i = 0; i < nactions; ++i )
246 {
247 assert(priorities[i] >= 0);
248 priosum += priorities[i];
249 }
250
251 /* if there are positive priorities, normalize the weights */
252 if( priosum > 0.0 )
253 {
254 normalization = nactions / priosum;
255 for( i = 0; i < nactions; ++i )
256 weights[i] = (priorities[i] * normalization) + NUMTOL;
257 }
258 else
259 {
260 /* use uniform distribution in case of all priorities being 0.0 */
261 for( i = 0; i < nactions; ++i )
262 weights[i] = 1.0 + NUMTOL;
263 }
264 }
265 else
266 {
267 /* use uniform distribution in case of unspecified priorities */
268 for( i = 0; i < nactions; ++i )
269 weights[i] = 1.0 + NUMTOL;
270 }
271
272 return SCIP_OKAY;
273}
274
275
276/*
277 * bandit algorithm specific interface methods
278 */
279
280/** direct bandit creation method for the core where no SCIP pointer is available */
282 BMS_BLKMEM* blkmem, /**< block memory data structure */
283 BMS_BUFMEM* bufmem, /**< buffer memory */
284 SCIP_BANDITVTABLE* vtable, /**< virtual function table for callback functions of Exp.3 */
285 SCIP_BANDIT** exp3, /**< pointer to store bandit algorithm */
286 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
287 SCIP_Real gammaparam, /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
288 SCIP_Real beta, /**< gain offset between 0 and 1 at every observation */
289 int nactions, /**< the positive number of actions for this bandit algorithm */
290 unsigned int initseed /**< initial random seed */
291 )
292{
293 SCIP_BANDITDATA* banditdata;
294
295 SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
296 assert(banditdata != NULL);
297
298 banditdata->gamma = gammaparam;
299 banditdata->beta = beta;
300 assert(gammaparam >= 0 && gammaparam <= 1);
301 assert(beta >= 0 && beta <= 1);
302
303 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->weights, nactions) );
304
305 SCIP_CALL( SCIPbanditCreate(exp3, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
306
307 return SCIP_OKAY;
308}
309
310/** creates and resets an Exp.3 bandit algorithm using \p scip pointer */
312 SCIP* scip, /**< SCIP data structure */
313 SCIP_BANDIT** exp3, /**< pointer to store bandit algorithm */
314 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
315 SCIP_Real gammaparam, /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
316 SCIP_Real beta, /**< gain offset between 0 and 1 at every observation */
317 int nactions, /**< the positive number of actions for this bandit algorithm */
318 unsigned int initseed /**< initial seed for random number generation */
319 )
320{
321 SCIP_BANDITVTABLE* vtable;
322
324 if( vtable == NULL )
325 {
326 SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
327 return SCIP_INVALIDDATA;
328 }
329
331 priorities, gammaparam, beta, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
332
333 return SCIP_OKAY;
334}
335
336/** set gamma parameter of Exp.3 bandit algorithm to increase weight of uniform distribution */
338 SCIP_BANDIT* exp3, /**< bandit algorithm */
339 SCIP_Real gammaparam /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
340 )
341{
342 SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
343
344 assert(gammaparam >= 0 && gammaparam <= 1);
345
346 banditdata->gamma = gammaparam;
347}
348
349/** set beta parameter of Exp.3 bandit algorithm to increase gain offset for actions that were not played */
351 SCIP_BANDIT* exp3, /**< bandit algorithm */
352 SCIP_Real beta /**< gain offset between 0 and 1 at every observation */
353 )
354{
355 SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
356
357 assert(beta >= 0 && beta <= 1);
358
359 banditdata->beta = beta;
360}
361
362/** returns probability to play an action */
364 SCIP_BANDIT* exp3, /**< bandit algorithm */
365 int action /**< index of the requested action */
366 )
367{
368 SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
369
370 assert(banditdata->weightsum > 0.0);
371 assert(SCIPbanditGetNActions(exp3) > 0);
372
373 return (1.0 - banditdata->gamma) * banditdata->weights[action] / banditdata->weightsum + banditdata->gamma / (SCIP_Real)SCIPbanditGetNActions(exp3);
374}
375
376/** include virtual function table for Exp.3 bandit algorithms */
378 SCIP* scip /**< SCIP data structure */
379 )
380{
381 SCIP_BANDITVTABLE* vtable;
382
384 SCIPbanditFreeExp3, SCIPbanditSelectExp3, SCIPbanditUpdateExp3, SCIPbanditResetExp3) );
385 assert(vtable != NULL);
386
387 return SCIP_OKAY;
388}
void SCIPbanditSetData(SCIP_BANDIT *bandit, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:200
SCIP_RETCODE SCIPbanditCreate(SCIP_BANDIT **bandit, SCIP_BANDITVTABLE *banditvtable, BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_Real *priorities, int nactions, unsigned int initseed, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:42
SCIP_BANDITDATA * SCIPbanditGetData(SCIP_BANDIT *bandit)
Definition: bandit.c:190
internal methods for bandit algorithms
SCIP_RETCODE SCIPbanditCreateExp3(BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_BANDITVTABLE *vtable, SCIP_BANDIT **exp3, SCIP_Real *priorities, SCIP_Real gammaparam, SCIP_Real beta, int nactions, unsigned int initseed)
Definition: bandit_exp3.c:281
#define NUMTOL
Definition: bandit_exp3.c:43
#define BANDIT_NAME
Definition: bandit_exp3.c:42
SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3)
Definition: bandit_exp3.c:144
SCIP_DECL_BANDITRESET(SCIPbanditResetExp3)
Definition: bandit_exp3.c:219
SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3)
Definition: bandit_exp3.c:87
SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3)
Definition: bandit_exp3.c:67
SCIP_RETCODE SCIPincludeBanditvtableExp3(SCIP *scip)
Definition: bandit_exp3.c:377
internal methods for Exp.3 bandit algorithm
#define NULL
Definition: def.h:267
#define SCIP_ALLOC(x)
Definition: def.h:385
#define SCIP_Real
Definition: def.h:173
#define EPSZ(x, eps)
Definition: def.h:203
#define SCIP_CALL(x)
Definition: def.h:374
int SCIPbanditGetNActions(SCIP_BANDIT *bandit)
Definition: bandit.c:303
void SCIPsetGammaExp3(SCIP_BANDIT *exp3, SCIP_Real gammaparam)
Definition: bandit_exp3.c:337
SCIP_RANDNUMGEN * SCIPbanditGetRandnumgen(SCIP_BANDIT *bandit)
Definition: bandit.c:293
SCIP_RETCODE SCIPcreateBanditExp3(SCIP *scip, SCIP_BANDIT **exp3, SCIP_Real *priorities, SCIP_Real gammaparam, SCIP_Real beta, int nactions, unsigned int initseed)
Definition: bandit_exp3.c:311
SCIP_BANDITVTABLE * SCIPfindBanditvtable(SCIP *scip, const char *name)
Definition: scip_bandit.c:80
SCIP_RETCODE SCIPincludeBanditvtable(SCIP *scip, SCIP_BANDITVTABLE **banditvtable, const char *name, SCIP_DECL_BANDITFREE((*banditfree)), SCIP_DECL_BANDITSELECT((*banditselect)), SCIP_DECL_BANDITUPDATE((*banditupdate)), SCIP_DECL_BANDITRESET((*banditreset)))
Definition: scip_bandit.c:48
void SCIPsetBetaExp3(SCIP_BANDIT *exp3, SCIP_Real beta)
Definition: bandit_exp3.c:350
SCIP_Real SCIPgetProbabilityExp3(SCIP_BANDIT *exp3, int action)
Definition: bandit_exp3.c:363
BMS_BUFMEM * SCIPbuffer(SCIP *scip)
Definition: scip_mem.c:72
SCIP_Real SCIPrandomGetReal(SCIP_RANDNUMGEN *randnumgen, SCIP_Real minrandval, SCIP_Real maxrandval)
Definition: misc.c:10130
unsigned int SCIPinitializeRandomSeed(SCIP *scip, unsigned int initialseedvalue)
#define BMSfreeBlockMemory(mem, ptr)
Definition: memory.h:465
#define BMSallocBlockMemory(mem, ptr)
Definition: memory.h:451
#define BMSallocBlockMemoryArray(mem, ptr, num)
Definition: memory.h:454
#define BMSfreeBlockMemoryArray(mem, ptr, num)
Definition: memory.h:467
struct BMS_BlkMem BMS_BLKMEM
Definition: memory.h:437
BMS_BLKMEM * SCIPblkmem(SCIP *scip)
Definition: scip_mem.c:57
public methods for bandit algorithms
public methods for message output
#define SCIPerrorMessage
Definition: pub_message.h:64
public data structures and miscellaneous methods
public methods for bandit algorithms
public methods for memory management
public methods for random numbers
struct SCIP_BanditData SCIP_BANDITDATA
Definition: type_bandit.h:56
@ SCIP_INVALIDDATA
Definition: type_retcode.h:52
@ SCIP_OKAY
Definition: type_retcode.h:42
enum SCIP_Retcode SCIP_RETCODE
Definition: type_retcode.h:63