Scippy

SCIP

Solving Constraint Integer Programs

bandit_exp3ix.c
Go to the documentation of this file.
/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
/* */
/* This file is part of the program and library */
/* SCIP --- Solving Constraint Integer Programs */
/* */
/* Copyright (c) 2002-2024 Zuse Institute Berlin (ZIB) */
/* */
/* Licensed under the Apache License, Version 2.0 (the "License"); */
/* you may not use this file except in compliance with the License. */
/* You may obtain a copy of the License at */
/* */
/* http://www.apache.org/licenses/LICENSE-2.0 */
/* */
/* Unless required by applicable law or agreed to in writing, software */
/* distributed under the License is distributed on an "AS IS" BASIS, */
/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */
/* See the License for the specific language governing permissions and */
/* limitations under the License. */
/* */
/* You should have received a copy of the Apache-2.0 license */
/* along with SCIP; see the file LICENSE. If not visit scipopt.org. */
/* */
/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */

/**@file   bandit_exp3ix.c
 * @ingroup OTHER_CFILES
 * @brief  methods for Exp.3-IX bandit selection
 * @author Antonia Chmiela
 */

/*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/

#include "scip/bandit.h"
#include "scip/bandit_exp3ix.h"
#include "scip/pub_bandit.h"
#include "scip/pub_message.h"
#include "scip/pub_misc.h"
#include "scip/scip_bandit.h"
#include "scip/scip_mem.h"
#include "scip/scip_randnumgen.h"

42#define BANDIT_NAME "exp3ix"
43
44/*
45 * Data structures
46 */
47
48/** implementation specific data of Exp.3 bandit algorithm */
49struct SCIP_BanditData
50{
51 SCIP_Real* weights; /**< exponential weight for each arm */
52 SCIP_Real weightsum; /**< the sum of all weights */
53 int iter; /**< current iteration counter to compute parameters gamma_t and eta_t */
54};
55
56/*
57 * Local methods
58 */
59
60/*
61 * Callback methods of bandit algorithm
62 */
63
64/** callback to free bandit specific data structures */
65SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3IX)
66{ /*lint --e{715}*/
67 SCIP_BANDITDATA* banditdata;
68 int nactions;
69 assert(bandit != NULL);
70
71 banditdata = SCIPbanditGetData(bandit);
72 assert(banditdata != NULL);
73 nactions = SCIPbanditGetNActions(bandit);
74
75 BMSfreeBlockMemoryArray(blkmem, &banditdata->weights, nactions);
76
77 BMSfreeBlockMemory(blkmem, &banditdata);
78
79 SCIPbanditSetData(bandit, NULL);
80
81 return SCIP_OKAY;
82}
83
84/** selection callback for bandit selector */
85SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3IX)
86{ /*lint --e{715}*/
87 SCIP_BANDITDATA* banditdata;
88 SCIP_RANDNUMGEN* rng;
89 SCIP_Real* weights;
90 SCIP_Real weightsum;
91 int i;
92 int nactions;
93 SCIP_Real psum;
94 SCIP_Real randnr;
95
96 assert(bandit != NULL);
97 assert(selection != NULL);
98
99 banditdata = SCIPbanditGetData(bandit);
100 assert(banditdata != NULL);
101 rng = SCIPbanditGetRandnumgen(bandit);
102 assert(rng != NULL);
103 nactions = SCIPbanditGetNActions(bandit);
104
105 /* initialize some local variables to speed up probability computations */
106 weightsum = banditdata->weightsum;
107 weights = banditdata->weights;
108
109 /* draw a random number between 0 and 1 */
110 randnr = SCIPrandomGetReal(rng, 0.0, 1.0);
111
112 /* loop over probability distribution until rand is reached
113 * the loop terminates without looking at the last action,
114 * which is then selected automatically if the target probability
115 * is not reached earlier
116 */
117 psum = 0.0;
118 for( i = 0; i < nactions - 1; ++i )
119 {
120 SCIP_Real prob;
121
122 /* compute the probability for arm i */
123 prob = weights[i] / weightsum;
124 psum += prob;
125
126 /* break and select element if target probability is reached */
127 if( randnr <= psum )
128 break;
129 }
130
131 /* select element i, which is the last action in case that the break statement hasn't been reached */
132 *selection = i;
133
134 return SCIP_OKAY;
135}
136
137/** compute gamma_t */
138static
140 int nactions, /**< the positive number of actions for this bandit algorithm */
141 int t /**< current iteration */
142 )
143{
144 return sqrt(log((SCIP_Real)nactions) / (4.0 * (SCIP_Real)t * (SCIP_Real)nactions));
145}
146
147/** update callback for bandit algorithm */
148SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3IX)
149{ /*lint --e{715}*/
150 SCIP_BANDITDATA* banditdata;
151 SCIP_Real etaparam;
152 SCIP_Real lossestim;
153 SCIP_Real prob;
154 SCIP_Real weightsum;
155 SCIP_Real newweightsum;
156 SCIP_Real* weights;
157 SCIP_Real gammaparam;
158 int nactions;
159
160 assert(bandit != NULL);
161
162 banditdata = SCIPbanditGetData(bandit);
163 assert(banditdata != NULL);
164 nactions = SCIPbanditGetNActions(bandit);
165
166 assert(selection >= 0);
167 assert(selection < nactions);
168
169 weights = banditdata->weights;
170 weightsum = banditdata->weightsum;
171 newweightsum = weightsum;
172 gammaparam = SCIPcomputeGamma(nactions, banditdata->iter);
173 etaparam = 2.0 * gammaparam;
174
175 /* probability of selection */
176 prob = weights[selection] / weightsum;
177
178 /* estimated loss */
179 lossestim = (1.0 - score) / (prob + gammaparam);
180 assert(lossestim >= 0);
181
182 /* update the observation for the current arm */
183 newweightsum -= weights[selection];
184 weights[selection] *= exp(-etaparam * lossestim);
185 newweightsum += weights[selection];
186
187 banditdata->weightsum = newweightsum;
188
189 /* increase iteration counter */
190 banditdata->iter += 1;
191
192 return SCIP_OKAY;
193}
194
195/** reset callback for bandit algorithm */
196SCIP_DECL_BANDITRESET(SCIPbanditResetExp3IX)
197{ /*lint --e{715}*/
198 SCIP_BANDITDATA* banditdata;
199 SCIP_Real* weights;
200 int nactions;
201 int i;
202
203 assert(bandit != NULL);
204
205 banditdata = SCIPbanditGetData(bandit);
206 assert(banditdata != NULL);
207 nactions = SCIPbanditGetNActions(bandit);
208 weights = banditdata->weights;
209
210 assert(nactions > 0);
211
212 /* initialize all weights with 1.0 */
213 for( i = 0; i < nactions; ++i )
214 weights[i] = 1.0;
215
216 banditdata->weightsum = (SCIP_Real)nactions;
217
218 /* set iteration counter to 1 */
219 banditdata->iter = 1;
220
221 return SCIP_OKAY;
222}
223
224
225/*
226 * bandit algorithm specific interface methods
227 */
228
229/** direct bandit creation method for the core where no SCIP pointer is available */
231 BMS_BLKMEM* blkmem, /**< block memory data structure */
232 BMS_BUFMEM* bufmem, /**< buffer memory */
233 SCIP_BANDITVTABLE* vtable, /**< virtual function table for callback functions of Exp.3-IX */
234 SCIP_BANDIT** exp3ix, /**< pointer to store bandit algorithm */
235 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
236 int nactions, /**< the positive number of actions for this bandit algorithm */
237 unsigned int initseed /**< initial random seed */
238 )
239{
240 SCIP_BANDITDATA* banditdata;
241
242 SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
243 assert(banditdata != NULL);
244
245 banditdata->iter = 1;
246
247 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->weights, nactions) );
248
249 SCIP_CALL( SCIPbanditCreate(exp3ix, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
250
251 return SCIP_OKAY;
252}
253
254/** creates and resets an Exp.3-IX bandit algorithm using \p scip pointer */
256 SCIP* scip, /**< SCIP data structure */
257 SCIP_BANDIT** exp3ix, /**< pointer to store bandit algorithm */
258 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
259 int nactions, /**< the positive number of actions for this bandit algorithm */
260 unsigned int initseed /**< initial seed for random number generation */
261 )
262{
263 SCIP_BANDITVTABLE* vtable;
264
266 if( vtable == NULL )
267 {
268 SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
269 return SCIP_INVALIDDATA;
270 }
271
273 priorities, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
274
275 return SCIP_OKAY;
276}
277
278/** returns probability to play an action */
280 SCIP_BANDIT* exp3ix, /**< bandit algorithm */
281 int action /**< index of the requested action */
282 )
283{
284 SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3ix);
285
286 assert(banditdata->weightsum > 0.0);
287 assert(SCIPbanditGetNActions(exp3ix) > 0);
288
289 return banditdata->weights[action] / banditdata->weightsum;
290}
291
292/** include virtual function table for Exp.3-IX bandit algorithms */
294 SCIP* scip /**< SCIP data structure */
295 )
296{
297 SCIP_BANDITVTABLE* vtable;
298
300 SCIPbanditFreeExp3IX, SCIPbanditSelectExp3IX, SCIPbanditUpdateExp3IX, SCIPbanditResetExp3IX) );
301 assert(vtable != NULL);
302
303 return SCIP_OKAY;
304}
void SCIPbanditSetData(SCIP_BANDIT *bandit, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:200
SCIP_RETCODE SCIPbanditCreate(SCIP_BANDIT **bandit, SCIP_BANDITVTABLE *banditvtable, BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_Real *priorities, int nactions, unsigned int initseed, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:42
SCIP_BANDITDATA * SCIPbanditGetData(SCIP_BANDIT *bandit)
Definition: bandit.c:190
internal methods for bandit algorithms
static SCIP_Real SCIPcomputeGamma(int nactions, int t)
SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3IX)
Definition: bandit_exp3ix.c:65
SCIP_RETCODE SCIPincludeBanditvtableExp3IX(SCIP *scip)
SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3IX)
Definition: bandit_exp3ix.c:85
SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3IX)
#define BANDIT_NAME
Definition: bandit_exp3ix.c:42
SCIP_RETCODE SCIPbanditCreateExp3IX(BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_BANDITVTABLE *vtable, SCIP_BANDIT **exp3ix, SCIP_Real *priorities, int nactions, unsigned int initseed)
SCIP_DECL_BANDITRESET(SCIPbanditResetExp3IX)
internal methods for Exp.3-IX bandit algorithm
#define NULL
Definition: def.h:267
#define SCIP_ALLOC(x)
Definition: def.h:385
#define SCIP_Real
Definition: def.h:173
#define SCIP_CALL(x)
Definition: def.h:374
int SCIPbanditGetNActions(SCIP_BANDIT *bandit)
Definition: bandit.c:303
SCIP_Real SCIPgetProbabilityExp3IX(SCIP_BANDIT *exp3ix, int action)
SCIP_RANDNUMGEN * SCIPbanditGetRandnumgen(SCIP_BANDIT *bandit)
Definition: bandit.c:293
SCIP_BANDITVTABLE * SCIPfindBanditvtable(SCIP *scip, const char *name)
Definition: scip_bandit.c:80
SCIP_RETCODE SCIPcreateBanditExp3IX(SCIP *scip, SCIP_BANDIT **exp3ix, SCIP_Real *priorities, int nactions, unsigned int initseed)
SCIP_RETCODE SCIPincludeBanditvtable(SCIP *scip, SCIP_BANDITVTABLE **banditvtable, const char *name, SCIP_DECL_BANDITFREE((*banditfree)), SCIP_DECL_BANDITSELECT((*banditselect)), SCIP_DECL_BANDITUPDATE((*banditupdate)), SCIP_DECL_BANDITRESET((*banditreset)))
Definition: scip_bandit.c:48
BMS_BUFMEM * SCIPbuffer(SCIP *scip)
Definition: scip_mem.c:72
SCIP_Real SCIPrandomGetReal(SCIP_RANDNUMGEN *randnumgen, SCIP_Real minrandval, SCIP_Real maxrandval)
Definition: misc.c:10130
unsigned int SCIPinitializeRandomSeed(SCIP *scip, unsigned int initialseedvalue)
#define BMSfreeBlockMemory(mem, ptr)
Definition: memory.h:465
#define BMSallocBlockMemory(mem, ptr)
Definition: memory.h:451
#define BMSallocBlockMemoryArray(mem, ptr, num)
Definition: memory.h:454
#define BMSfreeBlockMemoryArray(mem, ptr, num)
Definition: memory.h:467
struct BMS_BlkMem BMS_BLKMEM
Definition: memory.h:437
BMS_BLKMEM * SCIPblkmem(SCIP *scip)
Definition: scip_mem.c:57
public methods for bandit algorithms
public methods for message output
#define SCIPerrorMessage
Definition: pub_message.h:64
public data structures and miscellaneous methods
public methods for bandit algorithms
public methods for memory management
public methods for random numbers
struct SCIP_BanditData SCIP_BANDITDATA
Definition: type_bandit.h:56
@ SCIP_INVALIDDATA
Definition: type_retcode.h:52
@ SCIP_OKAY
Definition: type_retcode.h:42
enum SCIP_Retcode SCIP_RETCODE
Definition: type_retcode.h:63