John Ho committed on
Commit
9137c51
·
1 Parent(s): a792463

added do_sample to generate

Browse files
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -233,7 +233,10 @@ def inference(
233
 
234
  # Inference
235
  generated_ids = model.generate(
236
- **inputs, max_new_tokens=max_tokens, temperature=float(temperature)
 
 
 
237
  )
238
  generated_ids_trimmed = [
239
  out_ids[len(in_ids) :]
@@ -255,7 +258,10 @@ def inference(
255
  ).to("cuda", dtype=DTYPE)
256
 
257
  output = model.generate(
258
- **inputs, max_new_tokens=max_tokens, temperature=float(temperature)
 
 
 
259
  )
260
  output_text = processor.decode(
261
  output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
 
233
 
234
  # Inference
235
  generated_ids = model.generate(
236
+ **inputs,
237
+ max_new_tokens=max_tokens,
238
+ temperature=float(temperature),
239
+ do_sample=temperature > 0.0,
240
  )
241
  generated_ids_trimmed = [
242
  out_ids[len(in_ids) :]
 
258
  ).to("cuda", dtype=DTYPE)
259
 
260
  output = model.generate(
261
+ **inputs,
262
+ max_new_tokens=max_tokens,
263
+ temperature=float(temperature),
264
+ do_sample=temperature > 0.0,
265
  )
266
  output_text = processor.decode(
267
  output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True